Hello everyone,
I get the error below when calling lossfunc(model(X), y).backward(), and I can't figure out where the problem is.
I trained the following network on CIFAR-10:
import torch.nn as nn
from torchvision.models import vgg11

pt_vgg11 = vgg11(pretrained=False)
# print(pt_vgg11)

def seqVGG():
    # wrap the VGG11 feature extractor (and avgpool), a Flatten,
    # and the classifier into one nn.Sequential container
    features = nn.Sequential()
    features.add_module('features', nn.Sequential(*list(pt_vgg11.children())[:-1]))
    features.add_module('flat', nn.Flatten())
    features.add_module('classifier', nn.Sequential(*list(pt_vgg11.children())[2]))
    return features

vgg11_seq = seqVGG()
(the BackPACK package expects models to be sequences of PyTorch nn modules)
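For reference, the BackPACK examples I have seen write the model as one flat nn.Sequential. A minimal sketch of that pattern (the small CNN here is just an illustration, not my actual model):

import torch.nn as nn

# a toy flat Sequential for 3x32x32 inputs: every layer is a
# direct child of the container, with no nested submodules
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)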
Now I want to compute the diagonal of the Hessian (for uncertainty estimates) for one batch from the trainloader, using the following code
(adapted from the BackPACK website, backpack.pt):
from torch.nn import CrossEntropyLoss
from backpack import backpack, extend
from backpack.extensions import DiagHessian

# take a single batch from the trainloader
for X, y in trainloader:
    break

model = extend(vgg11_seq)
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(X), y)
with backpack(DiagHessian()):
    loss.backward()

for param in model.parameters():
    print(param.grad)
    print(param.diag_h)  # the DiagHessian extension stores its result in .diag_h
But I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-38-2fa093cbf3ef> in <module>()
9
10 with backpack(DiagHessian()):
---> 11 loss.backward()
12
13 for param in model.parameters():
~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
196 products. Defaults to ``False``.
197 """
--> 198 torch.autograd.backward(self, gradient, retain_graph, create_graph)
199
200 def register_hook(self, hook):
~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
98 Variable._execution_engine.run_backward(
99 tensors, grad_tensors, retain_graph, create_graph,
--> 100 allow_unreachable=True) # allow_unreachable flag
101
102
RuntimeError: 'NoneType' object is not subscriptable
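For comparison, this is the basic usage pattern I copied from the BackPACK docs, sketched here with a hypothetical toy linear model and random data (just the pattern, not my actual setup):

import torch
from torch.nn import CrossEntropyLoss, Linear
from backpack import backpack, extend
from backpack.extensions import DiagHessian

# hypothetical random batch: 4 samples, 10 features, 3 classes
X_toy = torch.randn(4, 10)
y_toy = torch.randint(0, 3, (4,))

toy = extend(Linear(10, 3))
toy_loss = extend(CrossEntropyLoss())

with backpack(DiagHessian()):
    toy_loss(toy(X_toy), y_toy).backward()

for p in toy.parameters():
    print(p.diag_h.shape)  # matches the parameter's shape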
Can anyone help me with this? If you need more information, just tell me!