Spatial Transformer
2021-07-22 13:43:30 # 机器学习

先讲仿射变换和双线性插值怎么做

就是把原本的坐标矩阵做一个线性变换,得到新的坐标矩阵,然后由于这个坐标矩阵里面有的不是整数,因此做双线性插值,对output的矩阵上每一个像素P(x,y)来说,它去变换出来的那小数矩阵上,根据公式计算它应该有的像素。

Bi-linear Interpolation

image-20210722165809211

然后再看STN

由此可见三个部分:

  • Localisation net :参数预测
  • Gridgenerator :坐标映射
  • ***Sampler:***像素采集

image-20210722163222431

听起来和玄乎,第一个参数预测其实就是前面说的仿射变换,对于图片来说,一个(2,3)的矩阵就可以完成各种变换(旋转,平移,放大,缩小)

第二个坐标映射就是 拿原本的坐标矩阵乘上仿射变换的矩阵,得到一个新的坐标矩阵 $X^S$

由于新的坐标矩阵可能有小数,要做处理,这里有两种方式,一种是向前映射,就是将新坐标对应的点(浮点数坐标的点)根据其周围四个点的位置,把自己的值按照权重分给周围四个点,最后结果是原本映射的像素点按照一定权重叠加出来的矩阵。

还有一个是向后映射,就是说我们最终output出来的这个新矩阵,它的shape我们是知道了,那根据它的每个点坐标,根据逆映射,可以找到对应原图像上的一个坐标(但是可能存在浮点数),然后用这个浮点数坐标周围的四个点来进行插值,就能计算出一个个的像素。(我本来以为这个地方要对变换矩阵求逆,但是其实不用,在初始定义的时候,变换矩阵就没有卡死说是对谁的变换,而且参数也是随着梯度下降逐渐更新的,所以一开始做坐标映射的时候就向后映就ok了)

反向传播我没有推,但是据说这样的方法是可以正确回传的,放上公式把。

img

参考

还有Pytorch官方的tutorial Spatial Transformer Networks Tutorial — PyTorch Tutorials 1.9.0+cu102 documentation

然后我在李宏毅2021的HW3里试了一下这个,但是由于HW3里本身就做过Data-aug,所以效果可能不是太好,而且具体超参数我也懒得调。

官方给的代码里面还是有很多东西的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

# Spatial transformer localization-network 这里就是参数预测部分,可以看到其实接了两层卷积池化,这里的参数是决定后面变换矩阵参数的。
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7), #卷积层的第一个参数是in_channel的Dimension,第二个是out_channel的Dimension
#Kernel_size是卷积核的大小
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)

# Regressor for the 3 * 2 affine matrix 这个矩阵对参数矩阵做变换,然后得到(N,2,3)的变换矩阵
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)

# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

# Spatial transformer network forward function
def stn(self, x):
#这里就是Spatial transformer的部分,先根据输入前传得到具体的theta矩阵
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
#这里的两个函数相当于专为这个STN网络而生,第一个是根据输入形状和theta矩阵得到变化后矩阵的Index矩阵(就刚刚说的那个里面有小数的矩阵)
grid = F.affine_grid(theta, x.size())
#这里把Index矩阵和输入一起输入,进行采样,然后出来的就是变化后的输入了。
x = F.grid_sample(x, grid)

return x

def forward(self, x):
# transform the input
x = self.stn(x)

# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)


model = Net().to(device)