I guess you might be using nn.CrossEntropyLoss as the loss_fct?
If so, note that this criterion accepts model outputs in the shape [batch_size, nb_classes, *] and targets either as LongTensors in the shape [batch_size, *] containing class indices in the range [0, nb_classes-1], or as FloatTensors in the same shape as the model output containing class probabilities.
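For illustration, a minimal sketch of the two accepted target formats (the shapes and values here are assumptions; probability targets require PyTorch >= 1.10):

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    batch_size, nb_classes = 4, 17
    logits = torch.randn(batch_size, nb_classes, requires_grad=True)

    # Option 1: class indices as a LongTensor in [0, nb_classes - 1]
    targets_idx = torch.randint(0, nb_classes, (batch_size,))
    loss_idx = criterion(logits, targets_idx)

    # Option 2: class probabilities as a FloatTensor, same shape as the logits
    targets_prob = torch.softmax(torch.randn(batch_size, nb_classes), dim=1)
    loss_prob = criterion(logits, targets_prob)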
The error is raised since you are detaching the predictions tensor from the computation graph via the argmax operation, which is not differentiable. It seems you are trying to fix this by re-wrapping the tensor, which will not "re-attach" it to the computation graph.
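A quick sketch of what is meant here (shapes taken from your post):

    import torch

    logits = torch.randn(4, 128, 17, requires_grad=True)
    predictions = logits.argmax(dim=-1)
    print(predictions.grad_fn)        # None -> detached, argmax is not differentiable
    print(predictions.requires_grad)  # False

    # Re-wrapping creates a new leaf tensor with no history, so gradients
    # still cannot flow back to logits:
    rewrapped = predictions.float().clone().requires_grad_(True)
    print(rewrapped.grad_fn)          # still None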
I removed this line: predictions = predictions.argmax(dim=-1)
and got this error:
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
IndexError: Target 17 is out of bounds.
The shapes are as follows:
batch_size = 4
seq_length = 128
n_classes = 17
Before logits shape : torch.Size([4, 128, 17])
After predictions shape : torch.Size([512, 17])
Before labels shape : torch.Size([4, 128])
After labels shape : torch.Size([512])
So the inputs to the CrossEntropyLoss are:
predictions shape : torch.Size([512, 17]) : (batch_size * seq_length, n_classes)
labels shape : torch.Size([512]): (batch_size * seq_length,)
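As a sanity check, a self-contained sketch with these shapes (random data, so purely illustrative) runs fine when the argmax is left out:

    import torch
    import torch.nn as nn

    batch_size, seq_length, n_classes = 4, 128, 17
    logits = torch.randn(batch_size, seq_length, n_classes, requires_grad=True)
    labels = torch.randint(0, n_classes, (batch_size, seq_length))

    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(logits.view(-1, n_classes),  # [512, 17]
                    labels.view(-1))             # [512]
    loss.backward()  # graph intact: no argmax between logits and the loss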
I used the same training data with the default Trainer and it worked without causing this error. I think the default loss is CrossEntropyLoss, so the target should not be out of bounds.
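One way to narrow this down is to inspect the raw label values before the loss is computed. nn.CrossEntropyLoss ignores targets equal to its ignore_index (-100 by default), which is what the Hugging Face data collators use for padded tokens, but any real label outside [0, n_classes - 1] will raise exactly this IndexError. A hypothetical check (the label values below are made up):

    import torch

    n_classes = 17
    labels = torch.tensor([[0, 5, 17, -100],
                           [1, 16, 2, -100]])  # 17 is one class too many

    flat = labels.view(-1)
    valid = flat[flat != -100]  # drop the ignored padding positions
    print(valid.min().item(), valid.max().item())
    # Any value outside [0, n_classes - 1] (here: 17) triggers
    # "IndexError: Target 17 is out of bounds."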
@ptrblck can we achieve the same functionality as predictions = predictions.argmax(dim=-1) without detaching the predictions tensor from the computation graph?
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
If I replace argmax with torch.max, it shouldn't detach the tensor from the graph: _, predictions = torch.max(logits, dim=2) or even this: _, predictions = torch.max(torch.tensor(logits, requires_grad=True), dim=2)
No, argmax is not differentiable, as already mentioned. torch.max will return the max values (which are still attached to the computation graph) and the argmax indices (which will be detached, as argmax is not differentiable).
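A small sketch of this behavior, reusing the shapes from above:

    import torch

    logits = torch.randn(4, 128, 17, requires_grad=True)
    values, indices = torch.max(logits, dim=2)

    print(values.grad_fn)   # e.g. <MaxBackward0 ...> -> still attached to the graph
    print(indices.grad_fn)  # None -> the indices are detached
    print(indices.dtype)    # torch.int64; integer tensors cannot require grad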