Official link
1: How is the repeat operation implemented?
2: How is the cut operation implemented?
3: How is the whole block implemented?
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import SyncBatchNorm


class FullyAttentionalBlock(nn.Module):
    def __init__(self, plane, norm_layer=SyncBatchNorm):
        super(FullyAttentionalBlock, self).__init__()
        self.conv1 = nn.Linear(plane, plane)  # acts on the last dimension; its size stays the same
        self.conv2 = nn.Linear(plane, plane)
        self.conv = nn.Sequential(nn.Conv2d(plane, plane, 3, stride=1, padding=1, bias=False),
                                  norm_layer(plane),
                                  nn.ReLU())
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_h = x.permute(0, 3, 1, 2).contiguous().view(batch_size * width, -1, height)   # [b*w, c, h]
        feat_w = x.permute(0, 2, 1, 3).contiguous().view(batch_size * height, -1, width)   # [b*h, c, w]
        # (b, c, h) ---> (b, h, c)
        encode_h = self.conv1(F.avg_pool2d(x, [1, width]).view(batch_size, -1, height).permute(0, 2, 1).contiguous())
        # (b, c, w) ---> (b, w, c)
        encode_w = self.conv2(F.avg_pool2d(x, [height, 1]).view(batch_size, -1, width).permute(0, 2, 1).contiguous())
        # (b*w, c, h) @ (b*w, h, c) = (b*w, c, c)
        energy_h = torch.matmul(feat_h, encode_h.repeat(width, 1, 1))
        # (b*h, c, w) @ (b*h, w, c) = (b*h, c, c)
        energy_w = torch.matmul(feat_w, encode_w.repeat(height, 1, 1))
        full_relation_h = self.softmax(energy_h)  # [b*w, c, c]
        full_relation_w = self.softmax(energy_w)  # [b*h, c, c]
        # [b*w, c, c] @ [b*w, c, h] = [b*w, c, h] ---> [b, w, c, h] ---> [b, c, h, w]
        full_aug_h = torch.bmm(full_relation_h, feat_h).view(batch_size, width, -1, height).permute(0, 2, 3, 1)
        # [b*h, c, c] @ [b*h, c, w] = [b*h, c, w] ---> [b, h, c, w] ---> [b, c, h, w]
        full_aug_w = torch.bmm(full_relation_w, feat_w).view(batch_size, height, -1, width).permute(0, 2, 1, 3)
        out = self.gamma * (full_aug_h + full_aug_w) + x
        out = self.conv(out)
        return out
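A quick sanity check on the block, my own sketch rather than official usage: it assumes nn.BatchNorm2d as the norm layer (SyncBatchNorm needs a distributed process group) and a made-up 64-channel 32x32 input.

import torch
import torch.nn as nn

block = FullyAttentionalBlock(plane=64, norm_layer=nn.BatchNorm2d)  # hypothetical settings
x = torch.rand(2, 64, 32, 32)   # (b, c, h, w)
out = block(x)
print(out.shape)                # torch.Size([2, 64, 32, 32]) -- same shape as the input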
First, let's look at how the dimensions change in the construction step:
F_in of shape (b, c, h, w) is average-pooled with kernels of size [1 x w] and [h x 1], giving tensors of shape (b, c, h) and (b, c, w). The linear layers leave these dimensions unchanged. The features are then copied w times and h times respectively, giving shapes (w*b, h, c) and (h*b, w, c). Here is a demo of the repeat function:
import torch
import torch.nn.functional as F

x = torch.rand(2, 3, 5, 5)                     # (b, c, h, w)
batch_size, channel, height, width = x.size()
# (2, 3, 5, 5) -> avg_pool over w -> (2, 3, 5) -> permute -> (2, 5, 3)
encode_h = F.avg_pool2d(x, [1, width]).view(batch_size, -1, height).permute(0, 2, 1).contiguous()
v = encode_h.repeat(width, 1, 1)               # tiled width times along the batch dimension
print(encode_h)
print(encode_h.shape)  # torch.Size([2, 5, 3])
print(v)
print(v.shape)         # torch.Size([10, 5, 3])
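The width branch works the same way. A small sketch of my own, reusing the same x, showing encode_w, its repeat along h, and that a Linear layer leaves the last dimension unchanged (the Linear here just stands in for self.conv2):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(2, 3, 5, 5)
batch_size, channel, height, width = x.size()
# (2, 3, 5, 5) -> avg_pool over h -> (2, 3, 5) -> permute -> (2, 5, 3)
encode_w = F.avg_pool2d(x, [height, 1]).view(batch_size, -1, width).permute(0, 2, 1).contiguous()
linear = nn.Linear(channel, channel)           # stand-in for self.conv2
print(linear(encode_w).shape)                  # torch.Size([2, 5, 3]) -- last dim unchanged
print(encode_w.repeat(height, 1, 1).shape)     # torch.Size([10, 5, 3])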
Then the cut operation is applied, and the result is matrix-multiplied with feat_h and feat_w, so the dimensions become [b*w, c, c] and [b*h, c, c]. Softmax leaves the dimensions unchanged. The resulting similarity map is multiplied with V (here feat_h / feat_w), giving [b*w, c, h], which is reshaped to [b, w, c, h] and then permuted to [b, c, h, w].
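A minimal shape walk-through of this step for the height branch, my own sketch with made-up sizes b=2, c=3, h=4, w=5 and random tensors (it only checks shapes, not learned weights):

import torch

b, c, h, w = 2, 3, 4, 5
feat_h = torch.rand(b * w, c, h)        # slices along the width axis, used as Q and V
encode_h = torch.rand(b, h, c)          # pooled + linear features, used as K
energy_h = torch.matmul(feat_h, encode_h.repeat(w, 1, 1))   # (b*w, c, c)
relation_h = torch.softmax(energy_h, dim=-1)                # shape unchanged
aug_h = torch.bmm(relation_h, feat_h)                       # (b*w, c, h)
out_h = aug_h.view(b, w, c, h).permute(0, 2, 3, 1)          # (b, c, h, w)
print(energy_h.shape, aug_h.shape, out_h.shape)
# torch.Size([10, 3, 3]) torch.Size([10, 3, 4]) torch.Size([2, 3, 4, 5])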
For questions 1 and 2:
The code calls the repeat function twice, encode_h.repeat(width, 1, 1) and encode_w.repeat(height, 1, 1), which copy the features along the w and h directions respectively.
In the original paper:
the tensors of shape c x 1 x w and c x h x 1 are copied along the h and w dimensions to become c x h x w, and are then cut along h and w; the cut is what forms the slices.
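Here is a small sketch of my own (not from the official code) of that repeat-then-cut idea for a single sample: expand c x h x 1 to c x h x w, then cut along w into w slices of shape c x h. Every slice is identical to the pooled map, which is why the code can simply repeat() the pooled features instead of expanding and cutting.

import torch

c, h, w = 3, 4, 5
pooled = torch.rand(c, h, 1)                      # c x h x 1, pooled over the width axis
expanded = pooled.expand(c, h, w)                 # copy along w -> c x h x w
slices = [expanded[..., i] for i in range(w)]     # cut along w -> w slices of shape c x h
print(len(slices), slices[0].shape)               # 5 torch.Size([3, 4])
print(torch.equal(slices[0], pooled.squeeze(-1))) # True -- each slice equals the pooled map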
The code does not show the cut explicitly, because right after repeat the features are multiplied with K, and the merge step is not shown either. In the figure, A has shape [(h+w), c, c]; in the code, the repeated features are multiplied with K directly, producing two tensors of shape [b*w, c, c] and [b*h, c, c]. Taken together, per sample these are w + h relation maps of size c x c, which corresponds to the [(h+w), c, c] A in the paper; the code simply keeps them separate (see the sketch after the snippet below).
# (b*w,c,h) * (b*w,h,c) = (b*w,c,c)
energy_h = torch.matmul(feat_h, encode_h.repeat(width, 1, 1))
# (b*h,c,w) * (b*h,w,c) = (b*h,c,c)
energy_w = torch.matmul(feat_w, encode_w.repeat(height, 1, 1))
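A small sketch of my own, with made-up sizes and random values (shapes only), showing how the two separate tensors line up with the paper's [(h+w), c, c] attention map A for each sample:

import torch

b, c, h, w = 2, 3, 4, 5
energy_h = torch.rand(b * w, c, c)   # relations from the width-direction slices
energy_w = torch.rand(b * h, c, c)   # relations from the height-direction slices
# regroup per sample and stack along the slice axis
A = torch.cat([energy_h.view(b, w, c, c), energy_w.view(b, h, c, c)], dim=1)
print(A.shape)   # torch.Size([2, 9, 3, 3]) -- (h + w) = 9 relation maps of size c x c per sample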
Finally, the resulting A is multiplied with V for each direction (here K should be equal to V), and the two results are added. This corresponds to the following lines:
full_aug_h = torch.bmm(full_relation_h, feat_h).view(batch_size, width, -1, height).permute(0, 2, 3, 1)
full_aug_w = torch.bmm(full_relation_w, feat_w).view(batch_size, height, -1, width).permute(0, 2, 1, 3)
out = self.gamma * (full_aug_h + full_aug_w) + x
For the remaining details, refer to the official code.