Thanks for the update.
The error is raised because you are explicitly casting the targets to torch.LongTensor, which is a CPU type.
Instead of using .type(torch.LongTensor), call label1.squeeze().long() to keep the tensor on the GPU.
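
Here is a small sketch of the difference. The values and the [batch_size, 1] shape of label1 are made up for illustration, and the snippet falls back to the CPU if no GPU is available:

```python
import torch

# Use the GPU if one is available; otherwise the example still runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical targets of shape [batch_size, 1], standing in for label1.
label1 = torch.tensor([[0.0], [2.0], [1.0]], device=device)

# torch.LongTensor is a CPU type, so this cast moves the result to the CPU.
on_cpu = label1.squeeze().type(torch.LongTensor)

# .long() changes only the dtype and leaves the tensor on its current device.
same_device = label1.squeeze().long()

print(on_cpu.device, same_device.device)
```

On a machine with a GPU, the first device should print as cpu and the second as the CUDA device that label1 lives on.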