Hello,
I have a network for attention prediction, which outputs a saliency map for a given image (a chess game).
But I just figured out that my output is rotated and flipped along the x axis (the salient pieces are not in the correct place).
Here is the network I created:
import torch
import torch.nn as nn
from torchvision import models

class VGGNet(nn.Module):
    def __init__(self):
        """Select conv3_3 ~ conv5_3 activation maps."""
        super(VGGNet, self).__init__()
        # indices of the ReLU outputs after conv3_3, conv4_3 and conv5_3 in VGG16
        self.select = [15, 22, 29]
        self.features = torch.nn.Sequential(
            # conv1
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=35),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv2
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv3
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv4
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv5
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2)
        )
        self.deconv1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 128, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 64, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, padding=0, stride=1),
            torch.nn.ReLU(),
        )
        self.deconv2_512_256 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(512, 256, 4, stride=2),
            torch.nn.ReLU()
        )
        self.deconv_512_512 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(512, 512, 4, stride=2),
            torch.nn.ReLU()
        )
        self.final_attention_pred = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(9, 1, 3, stride=1, padding=1)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        # initialize the encoder with the ImageNet-trained VGG16 weights from torchvision
        for i, layer in enumerate(models.vgg16(pretrained=True).features):
            if isinstance(layer, torch.nn.Conv2d):
                self.features[i].weight.data = layer.weight.data
                self.features[i].bias.data = layer.bias.data
        # initialize the decoder layers with small random weights
        for block in (self.deconv1, self.deconv2_512_256, self.deconv_512_512):
            for m in block:
                if isinstance(m, torch.nn.ConvTranspose2d):
                    m.weight.data.normal_(0.0, 0.01)
                    m.bias.data.zero_()
        for m in self.final_attention_pred:
            if isinstance(m, torch.nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.1)
                m.bias.data.zero_()

    def forward(self, x):
        # collect feature maps at three different scales
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.select:
                features.append(x)
        saliency = []
        m = nn.Sigmoid()
        m.register_backward_hook(printgradnorm)  # printgradnorm: debugging hook defined elsewhere
        # decode each scale back to the input resolution, then crop the extra padding off
        attentionmap1 = self.deconv1(features[0])[:, :, 36:260, 36:260]
        attentionmap1 = torch.cat([attentionmap1, attentionmap1, attentionmap1], 1)
        attentionmap2 = self.deconv1(self.deconv2_512_256(features[1]))[:, :, 42:266, 42:266]
        attentionmap2 = torch.cat([attentionmap2, attentionmap2, attentionmap2], 1)
        attentionmap3 = self.deconv1(self.deconv2_512_256(self.deconv_512_512(features[2])))[:, :, 54:278, 54:278]
        attentionmap3 = torch.cat([attentionmap3, attentionmap3, attentionmap3], 1)
        saliency.append(m(attentionmap1))
        saliency.append(m(attentionmap2))
        saliency.append(m(attentionmap3))
        # fuse the 9 stacked maps into the final prediction
        output_data = torch.cat(saliency, 1)
        output = m(self.final_attention_pred(output_data))
        return output
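For reference, here is a quick shape sanity check. It assumes a 224x224 RGB input (the hard-coded crop windows above only line up at that size) and uses a hypothetical stand-in for the printgradnorm hook, which isn't shown in the post:

# hypothetical stand-in for the author's printgradnorm debugging hook (not shown above)
def printgradnorm(module, grad_input, grad_output):
    print('grad_output norm:', grad_output[0].norm())

model = VGGNet()
dummy = torch.zeros(1, 3, 224, 224)
out = model(dummy)
print(out.shape)  # torch.Size([1, 1, 224, 224]) -- same height/width as the input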
The output is rotated 90 degrees to the left and flipped along the x axis (the bottom ends up at the top and the top at the bottom).
Any idea which part does that?
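One thing worth checking first: conv, ReLU, max-pool and transposed-conv layers all preserve spatial orientation, so the network itself shouldn't be able to rotate or flip anything. A 90-degree left rotation followed by a flip along the x axis is exactly a transpose of the height and width axes, which usually sneaks in around the model rather than inside it, e.g. mixing up PIL's (width, height) ordering with NumPy's (height, width), a reshape(w, h) instead of reshape(h, w) when building the ground-truth maps, or indexing map[x, y] instead of map[y, x]. A small check of that equivalence:

import numpy as np

a = np.arange(6).reshape(2, 3)
# rotating 90 degrees counter-clockwise and then flipping vertically (bottom <-> top)
# gives exactly the transpose of the two spatial axes
assert np.array_equal(np.flipud(np.rot90(a)), a.T)

If your prediction matches the target after a .transpose(-2, -1) (or comparing against map.T), the bug is in the data or visualization pipeline, not in the layers.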