I ran into this situation too.
My training code is as follows:
import torch
import torch.nn as nn

class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.layers = nn.ModuleList([])
        self.posemb = nn.Linear(2, 40)
        for i in range(3):
            self.layers.append(nn.Linear(2, 40))
        self.bias = None

    def forward(self, x, idx):
        if self.bias is None:
            self.bias = self.posemb(x).reshape(-1, 20, 2)
        pred = self.layers[idx](x)
        pred = pred.reshape(-1, 20, 2) + self.bias
        return pred

if __name__ == '__main__':
    net = testNet()
    for b in range(3):
        x = torch.rand(4, 2)
        label = torch.rand(4, 20, 2)
        ll = []
        for i in range(3):
            pred = net(x, i)
            x = pred[:, 0, :]
            ll.append(torch.norm(pred - label, p=-1).mean())
        loss = torch.stack(ll).mean()
        loss.backward()
        print('batch: %d | loss: %f' % (b, loss.item()))
In the first training loop everything is fine, but in the second loop it raises this error.
After checking for a long time, I finally found that the problem is in my network.
In my forward function, I cache an intermediate variable (self.bias) to avoid recomputing it.
In the second loop this variable should be recomputed, but my guard condition skips the recomputation, so the cached tensor still points at a computation graph that was already freed by the first backward().
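To see the failure mode in isolation, here is a minimal standalone sketch (not the original model, just the same caching pattern) that reproduces the error, assuming only plain PyTorch:

```python
import torch

# Cache a tensor that belongs to a graph, backward once (which frees
# the graph's saved buffers), then reuse the stale cache.
w = torch.ones(3, requires_grad=True)
cache = None
errors = []
for step in range(2):
    if cache is None:          # same kind of guard as `if self.bias is None`
        cache = w * w          # built in step 0, never rebuilt
    loss = (cache + w).sum()
    try:
        loss.backward()
    except RuntimeError as e:
        errors.append(str(e))

# step 0 succeeds; step 1 fails because cache's graph was already freed
print(len(errors))  # 1
print('backward through the graph a second time' in errors[0])  # True
```

The second backward() has to walk back through the node that produced `cache`, but that part of the graph was freed by the first backward(), hence the error.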
After I change
if self.bias is None:
    self.bias = self.posemb(x).reshape(-1, 20, 2)
to
if idx == 0:
    self.bias = self.posemb(x).reshape(-1, 20, 2)
the problem is solved!
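The same standalone sketch with the corrected guard, i.e. rebuilding the cache on every outer iteration (the analogue of switching to `if idx == 0`), runs cleanly:

```python
import torch

# Rebuild the cached tensor every iteration, so each backward()
# walks a freshly built graph instead of a freed one.
w = torch.ones(3, requires_grad=True)
completed = 0
for step in range(2):
    cache = w * w              # recomputed each iteration -> fresh graph
    loss = (cache + w).sum()
    loss.backward()            # succeeds in both iterations now
    w.grad = None              # reset accumulated grads between steps
    completed += 1

print(completed)  # 2
```

If the cached value never needed gradients at all, detaching it (`self.bias.detach()`) would be another option, but in my case posemb has to receive gradients, so recomputing once per batch is the right fix.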