The following snippet fails:
import torch
def test1():
layer = nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))
x.backward()
layer.weight.data = layer.weight.data[:, :90]
layer.weight.grad.data = layer.weight.grad.data[:, :90]
x = 5 - torch.sum(layer(torch.ones(90)))
x.backward()
test1()
with the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
8 x = 5 - torch.sum(layer(torch.ones(90)))
9 x.backward()
---> 10 test1()
11 # and this works as well
12
2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]
This works (the only change is deleting `x` before the weights are sliced):
import torch
def test2():
    """Shrink a Linear layer from 100 to 90 inputs after dropping the first loss.

    Runs one forward/backward pass on a ``Linear(100, 10)`` layer and deletes
    the resulting loss tensor. It then slices the weight and its gradient
    (through ``.data``) down to the first 90 input columns and runs a second
    forward/backward pass on a 90-element input.
    """
    net = torch.nn.Linear(100, 10)

    first_loss = 5 - net(torch.ones(100)).sum()
    first_loss.backward()
    # Release this function's reference to the first graph before the weight
    # is resliced.
    # NOTE(review): this deletion is what makes the second backward succeed;
    # presumably it lets autograd rebuild the weight's gradient bookkeeping
    # with the new shape. Confirm.
    del first_loss

    weight = net.weight
    weight.data = weight.data[:, :90]
    weight.grad.data = weight.grad.data[:, :90]

    second_loss = 5 - net(torch.ones(90)).sum()
    second_loss.backward()


test2()
And this works as well (the only change is re-wrapping the sliced weight in a new `torch.nn.Parameter`):
import torch
def test3():
    """Shrink a Linear layer to 90 inputs, then re-wrap its weight in a new Parameter.

    After one forward/backward pass on a ``Linear(100, 10)`` layer, the weight
    and its gradient are sliced (through ``.data``) to the first 90 input
    columns. The weight is then replaced by a fresh ``torch.nn.Parameter``
    built from it, and a second forward/backward pass runs on a 90-element
    input.
    """
    fc = torch.nn.Linear(100, 10)

    loss_before = 5 - fc(torch.ones(100)).sum()
    loss_before.backward()

    kept_weight = fc.weight.data[:, :90]
    fc.weight.data = kept_weight
    kept_grad = fc.weight.grad.data[:, :90]
    fc.weight.grad.data = kept_grad

    # Swap in a brand-new Parameter object for the weight; the second forward
    # pass below then builds its graph against this new leaf tensor.
    # NOTE(review): the new Parameter presumably starts with ``grad`` None,
    # so the gradient slicing above would have no lasting effect. Verify.
    fc.weight = torch.nn.Parameter(fc.weight)

    loss_after = 5 - fc(torch.ones(90)).sum()
    loss_after.backward()


test3()
I encountered this while implementing a paper on model pruning. I believe it has something to do with the autograd graph, but I am not sure exactly what is going on. Any explanation of why these almost identical code snippets work or fail would be much appreciated. Thanks in advance.