I guess you might be using nn.CrossEntropyLoss as the loss_fct?
If so, note that this criterion accepts model outputs in the shape [batch_size, nb_classes, *] and targets either as LongTensors in the shape [batch_size, *] containing class indices in the range [0, nb_classes-1], or as FloatTensors in the same shape as the model output containing probabilities.
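For reference, here is a minimal sketch of both accepted target formats, using the class count from your post (the probability-target variant needs a reasonably recent PyTorch, 1.10 or later):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
batch_size, nb_classes = 4, 17
logits = torch.randn(batch_size, nb_classes, requires_grad=True)

# Target as class indices: LongTensor of shape [batch_size],
# values in [0, nb_classes-1]
target_idx = torch.randint(0, nb_classes, (batch_size,))
loss_idx = criterion(logits, target_idx)

# Target as probabilities: FloatTensor with the same shape as the logits
target_prob = torch.softmax(torch.randn(batch_size, nb_classes), dim=1)
loss_prob = criterion(logits, target_prob)

print(loss_idx.item(), loss_prob.item())
```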
The error is raised because you are detaching the predictions tensor from the computation graph in the argmax operation, which is not differentiable. It seems you are trying to fix this by re-wrapping the tensor, but that will not “re-attach” it to the computation graph.
I removed this line predictions = predictions.argmax(dim=-1)
and got this error:
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
IndexError: Target 17 is out of bounds.
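This IndexError means a label with the value 17 appears while the logits only cover 17 classes, so the valid target range is 0 to 16. A minimal reproduction of the same failure (shapes are made up, only the out-of-range label matters):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 17)             # 17 classes -> valid targets are 0..16
target = torch.tensor([0, 5, 16, 17])   # 17 is out of bounds

err = None
try:
    criterion(logits, target)
except IndexError as e:
    err = e
print(err)
```

So the fix is either to remap the labels into [0, 16] or to give the model 18 output classes.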
The shapes are as follows:
batch_size = 4
seq_length = 128
n_classes = 17
Before logits shape : torch.Size([4, 128, 17])
After predictions shape : torch.Size([512, 17])
Before labels shape : torch.Size([4, 128])
After labels shape : torch.Size([512])
So the input of the CrossEntropyLoss:
predictions shape : torch.Size([512, 17]) : (batch_size * seq_length, n_classes)
labels shape : torch.Size([512]) : (batch_size * seq_length,)
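The flattening itself looks right for token classification. A sketch of the usual pattern with random logits and in-range labels (your shapes, my made-up data), showing that the loss stays attached to the graph as long as argmax is not applied to the logits:

```python
import torch
import torch.nn as nn

batch_size, seq_length, n_classes = 4, 128, 17
logits = torch.randn(batch_size, seq_length, n_classes, requires_grad=True)
labels = torch.randint(0, n_classes, (batch_size, seq_length))

loss_fct = nn.CrossEntropyLoss()
# Flatten to [batch_size * seq_length, n_classes] and [batch_size * seq_length]
loss = loss_fct(logits.view(-1, n_classes), labels.view(-1))
loss.backward()  # works: logits stayed attached to the graph
print(logits.grad.shape)
```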
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
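This RuntimeError can be reproduced with a minimal sketch: the argmax output carries no grad_fn, so any loss computed from it cannot be backpropagated.

```python
import torch

x = torch.randn(4, 17, requires_grad=True)
pred = x.argmax(dim=-1)        # integer indices, detached from the graph
print(pred.requires_grad)      # False

loss = pred.float().mean()     # any value computed from the detached tensor
err = None
try:
    loss.backward()
except RuntimeError as e:
    err = e
print(err)
```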
What if I replace argmax with torch.max? It shouldn’t detach the tensor from the graph: _, predictions = torch.max(logits, dim=2), or even _, predictions = torch.max(torch.tensor(logits, requires_grad=True), dim=2)
No, argmax is not differentiable, as already mentioned. torch.max will return the max values (which are still attached to the computation graph) and the argmax indices (which will be detached, as the index selection is not differentiable).
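You can verify both claims directly with your shapes:

```python
import torch

logits = torch.randn(4, 128, 17, requires_grad=True)
values, indices = torch.max(logits, dim=2)

print(values.requires_grad, values.grad_fn)  # max values stay attached
print(indices.requires_grad)                 # False: indices are detached
```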