Invalid gradient shape after discarding filters during training

Hello,

I’m trying to remove some filters during training; however, the .backward() call in the second iteration raises an error due to a size mismatch. The error is the following:

Traceback (most recent call last):
  File "fixingbackward.py", line 44, in <module>
    loss.backward() # Error here
  File "/home/user/data/envs/def/lib/python3.8/site-packages/torch/tensor.py", line 222, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/data/envs/def/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function CudnnConvolutionBackward returned an invalid gradient at index 1 - got [1, 7, 3, 3] but expected shape compatible with [1, 8, 3, 3]

The following code reproduces the error:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 1, 3, padding=1)

    def forward(self, x): 
        x = self.conv1(x)
        x = self.conv2(x)
        return x

def loss_fun(x):
    return torch.sum(x)

# Initialization
model = Model().cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
im = torch.rand((1, 1, 256, 256)).cuda()

# Iteration 1
output = model(im)
loss = loss_fun(output)

opt.zero_grad()
loss.backward()
opt.step()

# Remove the first filter of conv1 (and the matching input channel of conv2)
model.conv1.weight.data = model.conv1.weight.data[1:] # size = (7, 1, 3, 3)
model.conv1.bias.data = model.conv1.bias.data[1:] # size = (7,)
model.conv2.weight.data = model.conv2.weight.data[:, 1:] # size = (1, 7, 3, 3)
#### Alternatively, this seems to work, but it increases memory usage until it explodes
# model.conv1.weight = torch.nn.Parameter(model.conv1.weight.data[1:]) # (7, 1, 3, 3)
# model.conv1.bias = torch.nn.Parameter(model.conv1.bias.data[1:]) # (7)
# model.conv2.weight = torch.nn.Parameter(model.conv2.weight.data[:, 1:]) # (1, 7, 3, 3)

# Iteration 2
output = model(im)
loss = loss_fun(output)

opt.zero_grad()
loss.backward() # Error here
opt.step()

Following this post, I also tried 1) updating weight.grad.data in the same way, and 2) even re-initializing the optimizer right before Iteration 2. Neither approach worked.
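Concretely, those two attempts looked roughly like this (a sketch of what I mean, not my exact code):

# 1) Slice the gradients the same way as the weights
model.conv1.weight.grad.data = model.conv1.weight.grad.data[1:]
model.conv1.bias.grad.data = model.conv1.bias.grad.data[1:]
model.conv2.weight.grad.data = model.conv2.weight.grad.data[:, 1:]

# 2) Re-create the optimizer right before Iteration 2
opt = torch.optim.Adam(model.parameters(), lr=1e-4)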

After debugging for a long time I’m stuck, since I don’t know where the information about the shape comes from. In fact, I don’t understand why this problem arises at all. I would expect that when I call output = model(im), autograd notices that the shapes have changed and ignores whatever it knows from Iteration 1. It seems that something is stored from the previous iteration, but I have no clue what or where it could be.

Thanks!

That’s not the case: you are manipulating the underlying .data attributes, which Autograd doesn’t track, and doing so is not recommended.
Try assigning the new parameters directly and, if needed, wrap the assignment in a with torch.no_grad() block.
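A minimal sketch of what I mean, applied to your toy example (the slicing indices come from your snippet; re-creating the optimizer is my addition, since the old one still holds references to the discarded parameter tensors and their Adam state):

with torch.no_grad():
    # Re-register the pruned weights as fresh nn.Parameter objects
    model.conv1.weight = torch.nn.Parameter(model.conv1.weight[1:].clone())      # (7, 1, 3, 3)
    model.conv1.bias = torch.nn.Parameter(model.conv1.bias[1:].clone())          # (7,)
    model.conv2.weight = torch.nn.Parameter(model.conv2.weight[:, 1:].clone())   # (1, 7, 3, 3)

# Build a new optimizer for the pruned parameters
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

The .clone() calls turn the sliced views into independent copies, so the original full-size tensors can be freed.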

Edit: I’ve solved the problem, so I’ve rewritten my initial answer. First I describe what I initially thought the problem was, and then the actual problem I had (which was embarrassingly simple).

I initially thought that with torch.nn.Parameter I was creating new parameters all the time, and that this was why I ran out of memory. In fact, the following code does run out of memory:

with torch.no_grad():
    for mod in self.layer.modules():
        if isinstance(mod, (torch.nn.Conv2d, torch.nn.Conv3d)):
            # Re-wrap the existing weight tensor in a new nn.Parameter
            pp = mod.weight
            del mod.weight
            mod.weight = torch.nn.Parameter(pp.data)

whereas the following doesn’t make the code run out of memory. Note: I tested this on my actual script, not on the toy example I gave in the initial post.

with torch.no_grad():
    for mod in self.layer.modules():
        if isinstance(mod, (torch.nn.Conv2d, torch.nn.Conv3d)):
            # Re-assign the very same Parameter object; nothing new is created
            pp = mod.weight
            del mod.weight
            mod.weight = pp

However, I examined what’s in memory (following this), and I noticed that the number of objects in memory (including torch.nn.Parameters) and their total size increase at first but then become constant in both cases. How is it possible that the script runs out of memory even though the number of tensors and parameters, and their total size, stay constant?
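For reference, that check was roughly along these lines (a sketch of the gc-based approach, not the exact snippet from the link):

import gc
import torch

def count_live_tensors():
    # Tally the tensors visible to the Python garbage collector
    # (nn.Parameter is a Tensor subclass, so it is included)
    n, total_bytes = 0, 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
                total_bytes += obj.element_size() * obj.nelement()
        except Exception:
            pass  # some objects raise on inspection; skip them
    return n, total_bytes

In hindsight, as far as I understand, tensors that are only kept alive by the autograd graph (for example activations saved for backward) live on the C++ side and don’t show up in gc.get_objects(), so a constant count here doesn’t rule out the graph itself growing.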

Solution
In my code I kept track of the loss in this way:

tr_loss_iteration = loss(output, Y)
tr_loss += tr_loss_iteration # For knowing the avg loss in the current epoch.

Apparently, tr_loss was accumulating something, although I wasn’t sure what, because, as I’ve said, the number of tensors and parameters in memory was constant. Maybe it was keeping references to all the “new” parameters? Anyway, I solved the memory leak simply with:

tr_loss += tr_loss_iteration.cpu().detach().numpy()
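An equivalent fix that I find slightly cleaner is .item(), which converts the 0-dim loss tensor to a plain Python float, so the running sum no longer holds onto anything from the graph. A sketch of the loop (loader is a placeholder name, not my actual dataloader):

tr_loss = 0.0
for X, Y in loader:  # placeholder dataloader
    output = model(X)
    tr_loss_iteration = loss(output, Y)

    opt.zero_grad()
    tr_loss_iteration.backward()
    opt.step()

    # .item() returns a plain Python float, so nothing from the
    # autograd graph is kept alive between iterations
    tr_loss += tr_loss_iteration.item()

avg_loss = tr_loss / len(loader)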

Thanks!