Hi all, first post here.
I am getting the error in the title when I am calling loss.backward:
criterion = nn.MSELoss()
…
train_epoch_loss = 0
train_epoch_acc = 0
model.train()
for X_train_batch, y_train_batch in train_loader:
    X_train_batch, y_train_batch = X_train_batch.to(device), y_train_batch.to(device)
    optimizer.zero_grad()
    y_train_pred = model(X_train_batch).argmax(dim=1).float()
    train_loss = criterion(y_train_pred, y_train_batch)
    train_acc = multi_acc(y_train_pred, y_train_batch)
    train_loss.backward()
    optimizer.step()
    train_epoch_loss += train_loss.item()
    train_epoch_acc += train_acc.item()
Python gives me this error:
RuntimeError                              Traceback (most recent call last)
in
    137
    138
--> 139 root = Tree("root-", train,
    140             homo_threshold = THRESH)

in __init__(self, name, df, homo_threshold)
     52         self.kid_rank = None
     53
---> 54         self.model = make_nn(df, verbose=False) if self.kids == None else None
     55
     56     def Error(self, pred, y):

in make_nn(df, verbose)
    149     train_acc = multi_acc(y_train_pred, y_train_batch)
    150
--> 151     train_loss.backward()
    152     optimizer.step()
    153

D:\anaconda3\lib\site-packages\torch\_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253             create_graph=create_graph,
    254             inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256
    257     def register_hook(self, hook):

D:\anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
So far I suspect that calling argmax is the culprit, but I am not 100% certain of that, nor of why or how. I had a previous version that used nn.CrossEntropyLoss and it worked fine (with minor tweaks), but once I add argmax to the mix it throws this error.
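To make the suspicion concrete, here is a tiny standalone repro (made-up tensors, not my actual model) showing that argmax detaches the result from the graph:

```python
import torch

# Stand-in for model outputs; any tensor that requires grad will do
logits = torch.randn(4, 3, requires_grad=True)

probs = logits.softmax(dim=1)
print(probs.requires_grad)           # True: still attached to the graph

preds = probs.argmax(dim=1).float()  # argmax returns plain integer indices
print(preds.requires_grad)          # False: no grad_fn, so backward() fails
```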
A common refrain in other threads suggests I somehow broke the computation graph by mistake; is there a way to visualize the graph so I can troubleshoot it myself?
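The closest I have found to "visualizing" the graph without extra libraries is walking grad_fn by hand (toy tensors here, not my script; the walk helper is my own):

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
out = (a * 3).sum()

def walk(fn, depth=0):
    """Recursively print the backward graph hanging off a tensor."""
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for child, _ in fn.next_functions:
        walk(child, depth + 1)

walk(out.grad_fn)  # e.g. SumBackward0 -> MulBackward0 -> AccumulateGrad
```

A tensor produced by argmax has `grad_fn is None`, so the walk stops dead there, which is exactly what the error message is complaining about.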
This is part of broader tests to create my own loss function.
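For reference, the direction I was trying to go: a differentiable stand-in that compares softmax probabilities against one-hot targets instead of hard argmax. This is just a sketch (the name soft_mse_loss is mine), not a claim that it is the right loss:

```python
import torch
import torch.nn.functional as F

def soft_mse_loss(logits, targets):
    """MSE between softmax probabilities and one-hot targets.

    Unlike argmax, every op here is differentiable, so the
    computation graph stays intact for backward().
    """
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(targets, num_classes=logits.size(1)).float()
    return F.mse_loss(probs, one_hot)

logits = torch.randn(4, 3, requires_grad=True)
targets = torch.tensor([0, 2, 1, 0])
loss = soft_mse_loss(logits, targets)
loss.backward()  # no RuntimeError: the graph is intact
```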
I know I am banging rocks together here, so what am I doing wrong?
PS: I tried to cut the code down to the bare minimum to make the post more legible; IIRC I didn't change much more than that particular section.
In case I cut too much, this is the full source code of the script I am writing: