Grad values are None when tensors undergo shallow copy

Hi. I have a model whose forward method performs a shallow copy of the tensors into a dictionary before returning, like so:

def forward(self, input):
    block0 = self.block0(input)
    block1 = self.block1(block0)
    self.end_points = {}
    self.end_points['block0'] = (block0, 0)
    self.end_points['block1'] = (block1, 1)
    return block0 

Where self.block0 and self.block1 are nn.Conv2d layers followed by batch norm and leaky ReLU.
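For reference, each block is roughly equivalent to the following (a simplified sketch; the class name, channel arguments, and kernel size are just illustrative):

import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU, roughly how block0 and block1 are built."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))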

If I now do:

output = model(input)
loss = output.mean()
loss.backward()
print(model.block0.conv.bias.grad)  # block0 is an nn.Module which contains an attribute conv that is an nn.Conv2d

The grad value is None. The outcome is similar if I just return the self.end_points dict instead of block0.

On the other hand, with the following forward function:

def forward(self, input):
    block0 = self.block0(input)
    block1 = self.block1(block0)
    self.end_points = {}
    self.end_points['block0'] = (block0, 0)
    return block0
    self.end_points['block1'] = (block1, 1)  # never reached, so block1 is not stored on self

With this version, the grad attribute of model.block0.conv.bias gets accumulated with the correct gradient.

I only run into this problem when I wrap the module in nn.DataParallel. I’m using the following workaround because I need to access some custom attributes and methods of the wrapped module:

class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to call model's attributes.
    """
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
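For context, I wrap and use the model roughly like this (MyModel here is just a placeholder for my actual network):

model = MyDataParallel(MyModel()).cuda()
output = model(input)
loss = output.mean()
loss.backward()
print(model.block0.conv.bias.grad)  # attribute lookup falls through to self.module via the override above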

I am not able to understand why this is the case. Please help! Thank you in advance.

Could you post the definition of self.block0 and self.block1?
Based on your description, I would assume you’ve defined them as nn.Sequential modules, but then you would get an error calling model.block0.grad.

My apologies. I did not mention earlier that I encounter this problem only when I wrap my module inside nn.DataParallel. I have updated the description above.

As to your question, both of those blocks are nn.Modules which hold an instance of nn.Conv2d, nn.LeakyReLU, and nn.BatchNorm2d, and these are called on the input in the forward method in that order.

The problem is fixed by not storing end_points as an attribute on the module and returning the entire dict instead. I suspect the problem is along the lines of https://github.com/pytorch/pytorch/issues/16532. I’m not sure, though, where the tensor goes out of scope and triggers a recursive deletion of the rest of the graph.
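For anyone hitting the same issue, the forward that works for me looks roughly like this (sketch):

def forward(self, input):
    block0 = self.block0(input)
    block1 = self.block1(block0)
    # keep the dict local instead of storing it on self
    end_points = {}
    end_points['block0'] = (block0, 0)
    end_points['block1'] = (block1, 1)
    return end_points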
