The first conv2d layer does not calculate gradients at all

Long story short, I cannot get the correct gradients (at least in theory) when trying to do backpropagation.

import torch
import torch.nn as nn
from torchvision import models

class MyNet(nn.Module):
    def __init__(self, num_classes):
        super(MyNet, self).__init__()
        # 1x1 conv that maps the last three input channels down to two
        self.input_conv = nn.Conv2d(3, 2, 1)
        self.feature_extractor = models.resnet50(pretrained=True)  # resnext50_32x4d, resnet50, resnext101_32x8d
        num_ftrs = self.feature_extractor.fc.in_features  # for resnet
        self.output_fc = nn.Linear(num_ftrs, num_classes, bias=False)

    def forward(self, x):
        # keep channel 0 as-is, pass channels 1:4 through input_conv,
        # then concatenate back into a 3-channel tensor for the ResNet
        x = torch.cat((x[:, 0, :, :].unsqueeze(1), self.input_conv(x[:, 1:])), 1)
        x = self.feature_extractor(x)
        y = self.output_fc(x)
        return y
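For context, the network expects a 4-channel input: channel 0 is passed through unchanged, channels 1:4 go through the 3-to-2 conv, and the concatenation yields the 3 channels the ResNet expects. A quick shape check (num_classes=10 is an arbitrary placeholder):

net = MyNet(num_classes=10)
dummy = torch.randn(1, 4, 224, 224)
print(net(dummy).shape)  # torch.Size([1, 10])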

Here’s my network. I also set up a hook function so that I could store the gradients. (The code is copied from the GitHub repo pytorch-cnn-visualization.)

    def hook_layers(self):
        def hook_function(module, grad_in, grad_out):
            # store the gradients w.r.t. the layer's input and output
            # (self.gradients is assumed to be initialized as a list elsewhere)
            self.gradients.append(grad_in[0])
            self.gradients.append(grad_out[0])
            print(grad_in[0].shape, grad_out[0].shape)

        # register the backward hook on the first child module of self.model,
        # which here is input_conv
        children = list(self.model.children())
        first_layer = children[0]
        first_layer.register_backward_hook(hook_function)

So in this network, first_layer is set to torch.nn.Conv2d(3, 2, 1), which makes sense. However, when I look at the gradients stored by the hook function, the shapes of both grad_in and grad_out are [1, 2, 224, 224], whereas one of them should be [1, 3, 224, 224]. When I print them out and compare them, their contents are exactly the same. Is this a bug, or did I do something wrong?

It’s a known issue, described in the Warning in the register_backward_hook docs:

The current implementation will not have the presented behavior for complex Module that perform many operations. In some failure cases, grad_input and grad_output will only contain the gradients for a subset of the inputs and outputs. For such Module, you should use torch.Tensor.register_hook() directly on a specific input or output to get the required gradients.
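Following that suggestion, here is a minimal sketch of the workaround: call torch.Tensor.register_hook() on the exact tensors you care about inside the forward pass, instead of a module backward hook. (net, the 4-channel input, and num_classes=10 are placeholders matching the model above.)

import torch

grads = []

def forward_with_hooks(net, x):
    conv_in = x[:, 1:]                    # the actual input to input_conv
    conv_in.register_hook(grads.append)   # gradient w.r.t. the conv input
    conv_out = net.input_conv(conv_in)
    conv_out.register_hook(grads.append)  # gradient w.r.t. the conv output
    x = torch.cat((x[:, 0, :, :].unsqueeze(1), conv_out), 1)
    x = net.feature_extractor(x)
    return net.output_fc(x)

net = MyNet(num_classes=10)
x = torch.randn(1, 4, 224, 224, requires_grad=True)
forward_with_hooks(net, x).sum().backward()
# hooks fire in reverse order during backward:
print(grads[0].shape)  # torch.Size([1, 2, 224, 224]) -- conv output grad
print(grads[1].shape)  # torch.Size([1, 3, 224, 224]) -- conv input grad

On PyTorch 1.8 and later, Module.register_full_backward_hook() is also available and reports grad_input/grad_output correctly for cases like this.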
