How to delete every grad after training?

Is there a way to delete all .grad attributes after training?
I am currently implementing some pruning techniques that reduce the dimensions of the weight tensors in convolution layers. For this to work, I need to set all gradients to None or delete them, since they no longer match the size of the filters. However, even after setting every weight.grad to None, I still receive an error:

  File "/home/bullseye/Documents/git/Master-Thesis-Pruning-for-Object-Detection/docker-omgeving/Code/trainengine.py", line 39, in process_batch
    loss.backward()
  File "/home/bullseye/.local/lib/python3.6/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/bullseye/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function CudnnConvolutionBackward returned an invalid gradient at index 1 - got [1024, 527, 3, 3] but expected shape compatible with [1024, 767, 3, 3]

So I am either missing some gradients or the ones I’m setting to None don’t act like I thought they would.
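For reference, clearing the .grad of every parameter in a model can be sketched as below (clear_all_grads and the small network are hypothetical stand-ins). Recent PyTorch releases also offer model.zero_grad(set_to_none=True) for the same purpose. As the rest of the thread shows, clearing .grad on its own does not fix this particular error.

```python
import torch
import torch.nn as nn


def clear_all_grads(model: nn.Module) -> None:
    """Set the .grad of every parameter in `model` to None."""
    for param in model.parameters():
        param.grad = None


# Hypothetical toy network, only used to produce some gradients.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
model(torch.randn(2, 3, 16, 16)).sum().backward()

clear_all_grads(model)
print(all(p.grad is None for p in model.parameters()))  # True
```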

Hi,

This error does not come from gradient accumulation but from the fact that the convolution returns a gradient for an input that does not have the same size as the input itself.
I guess you're playing tricks and changing some sizes in between the forward and the backward?

Yes, the dimensions of the input tensors are changed in between training rounds. The pruning sequence goes roughly like this: find filters that don't contribute much -> delete them -> retrain to recover the lost accuracy -> repeat from the first step.
The problem is indeed that the dimensions of the gradients aren't changed, which is why I wanted them to be deleted (or set to None) each time the network has been retrained.
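As a rough sketch of the "find filters that don't contribute much" step: one common heuristic, and only an assumption here since the thread does not name a criterion, is to rank a layer's output filters by the L1 norm of their weights. The function name least_important_filter is hypothetical.

```python
import torch
import torch.nn as nn


def least_important_filter(conv: nn.Conv2d) -> int:
    """Return the index of the output filter with the smallest L1 norm.

    The L1-norm criterion is an illustrative assumption, not the
    thread author's method.
    """
    with torch.no_grad():
        # weight shape: [out_channels, in_channels, kH, kW]
        l1 = conv.weight.abs().sum(dim=(1, 2, 3))
    return int(torch.argmin(l1))


conv = nn.Conv2d(3, 4, 3)
with torch.no_grad():
    conv.weight.fill_(1.0)
    conv.weight[2].zero_()  # make filter 2 the "weakest" one
print(least_important_filter(conv))  # 2
```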

But after you delete some of the filters, do you run a new forward with the smaller filters before running the backward? The error here is a mismatch between the size during forward and during backward.

@albanD Yes, the training set is split into multiple batches, and for each batch the following code is run:

data = data.to(self.device)
out = self.network(data)
loss = self.loss(out, target) / self.batch_subdivisions
loss.backward()

We retrain for 1 epoch (around 166 batches).
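The division by self.batch_subdivisions suggests that gradients from several sub-batches are accumulated before each optimizer step. A self-contained sketch of one such retraining epoch, assuming that reading (train_epoch and the toy setup are hypothetical), could look like this; note that every backward follows a fresh forward through the current, possibly pruned, network.

```python
import torch
import torch.nn as nn


def train_epoch(network, loss_fn, optimizer, batches, batch_subdivisions, device):
    """One epoch with gradient accumulation over `batch_subdivisions` batches.

    Assumes one optimizer step per `batch_subdivisions` consecutive batches.
    """
    network.train()
    optimizer.zero_grad()
    for i, (data, target) in enumerate(batches, start=1):
        data, target = data.to(device), target.to(device)
        out = network(data)
        # Scale so the accumulated gradient is an average over sub-batches.
        loss = loss_fn(out, target) / batch_subdivisions
        loss.backward()
        if i % batch_subdivisions == 0:
            optimizer.step()
            optimizer.zero_grad()


# Hypothetical toy setup: 6 batches, so 2 optimizer steps.
torch.manual_seed(0)
net = nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(6)]
before = net.weight.detach().clone()
train_epoch(net, nn.CrossEntropyLoss(), opt, batches, batch_subdivisions=3, device="cpu")
print(torch.equal(before, net.weight.detach()))
```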

@albanD Apparently, moving the network to the CPU and then back to the GPU fixes the problem. I'm assuming this is because the gradients aren't kept when changing device?

That is unexpected… How do you replace your parameters with the new, pruned ones?

@albanD I prune like this:

m.weight.data = torch.cat((m.weight.data[:filter[1]], m.weight.data[filter[1]+1:]))

I also prune the biases and the dependent input channels in the same manner.
I then lower the out_channels attribute of the pruned layer. After pruning all layers, I tried to set all gradients to None, although I'm still unsure whether this works.

Hooo, one more proof for https://github.com/pytorch/pytorch/issues/30987 that .data is the source of all evil!

The short answer is: do not use .data 🙂
The longer answer is:

m.weight= nn.Parameter(torch.cat((m.weight[:filter[1]], m.weight[filter[1]+1:])))

And make sure to give these new weights to your optimizer again after changing them!
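Put together, pruning one output filter this way might look like the sketch below. It assumes a plain Conv2d (groups=1) directly followed by another Conv2d whose input channels depend on it; prune_filter and the toy model are hypothetical. Recreating the optimizer is one simple way to hand it the new weights, at the cost of dropping any optimizer state such as momentum.

```python
import torch
import torch.nn as nn


def prune_filter(conv: nn.Conv2d, next_conv: nn.Conv2d, idx: int) -> None:
    """Remove output filter `idx` of `conv` and the matching input channel of
    `next_conv` by assigning new nn.Parameter objects (no `.data`)."""
    keep = [i for i in range(conv.out_channels) if i != idx]
    with torch.no_grad():
        # Advanced indexing returns copies, so these are brand-new leaves.
        conv.weight = nn.Parameter(conv.weight[keep])
        if conv.bias is not None:
            conv.bias = nn.Parameter(conv.bias[keep])
        # Input channels live in dim 1 of the next layer's weight.
        next_conv.weight = nn.Parameter(next_conv.weight[:, keep])
    conv.out_channels -= 1
    next_conv.in_channels -= 1


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(2, 3, 16, 16)
model(x).sum().backward()

prune_filter(model[0], model[2], idx=5)
# The old optimizer still references the removed Parameters: rebuild it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
model(x).sum().backward()  # new forward with the smaller filters
optimizer.step()
print(model[0].weight.grad.shape)  # torch.Size([7, 3, 3, 3])
```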


I will test this tomorrow. Thank you very much for your help so far.

@albanD Hi again, I tested your proposed solution and it works! Thank you very much for your help.
Would you be able to explain why it has to be done this way though?

Changing .data does several things. In your particular case, it changes the underlying Tensor as you wanted. But it does not reset the autograd Node that is responsible for accumulating gradients into that parameter, and that Node still expects the old shape. So the accumulation fails even though you deleted the .grad field.
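A minimal sketch reproducing this on a toy layer follows. In it, the first output is deliberately kept alive, which keeps its graph (and with it the accumulation Node for conv.weight) around; whether that Node survives in a real training loop depends on what still references the old graph. The exact error text may vary between PyTorch versions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 4, 3)
x = torch.randn(1, 3, 8, 8)

# First forward/backward. Holding on to `out1` keeps its graph alive, and the
# graph holds the gradient-accumulation node created for conv.weight.
out1 = conv(x)
out1.sum().backward()

# Prune filter 0 "the .data way" and clear the gradients.
conv.weight.data = conv.weight.data[1:].clone()
conv.bias.data = conv.bias.data[1:].clone()
conv.weight.grad = None
conv.bias.grad = None

error_message = None
try:
    # New forward with the smaller filters, but the old accumulation node
    # still expects a gradient of shape [4, 3, 3, 3].
    conv(x).sum().backward()
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# The fix from this thread: new Parameter objects are new leaves, so autograd
# creates fresh accumulation nodes for them.
conv.weight = nn.Parameter(conv.weight.detach().clone())
conv.bias = nn.Parameter(conv.bias.detach().clone())
conv(x).sum().backward()
print(conv.weight.grad.shape)  # torch.Size([3, 3, 3, 3])
```

This also fits the earlier observation that a CPU round trip helped: as far as I can tell, PyTorch only discards a leaf's cached accumulation Node when .data is set to a tensor on a different device or with a different dtype, which is what moving the module with .to() does.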