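import torch
from torch.autograd import Variable  # deprecated since PyTorch 0.4; now just returns tensors
from tqdm import tqdm
for batch_id, (data, target) in enumerate(tqdm(train_loader)):
    print('Entered for loop')
    # One-hot encode the labels by picking rows of a 10x10 identity matrix (this line raises the error)
    target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)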
The line containing the index_select call raises this error, and I haven't been able to find a solution anywhere. Printing the target variable gives this:
How do I convert the target variable into a vector? Isn’t it already a vector?
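For anyone landing here: a tensor that prints as a column, e.g. shape (N, 1), has two dimensions, so index_select does not treat it as a vector. The index has to be 1-D with an integer dtype. The printed target isn't shown above, so this minimal sketch assumes a hypothetical (4, 1) label tensor:

```python
import torch

# Hypothetical labels: the real printed target isn't shown, so assume
# a batch of 4 class ids stored as a column of shape (4, 1).
target = torch.tensor([[3], [0], [7], [3]])
print(target.shape, target.dim())  # torch.Size([4, 1]) 2 -> 2-D, not a vector

eye = torch.eye(10)
raised = False
try:
    eye.index_select(dim=0, index=target)  # a 2-D index is rejected
except (RuntimeError, IndexError) as err:  # exception type varies by PyTorch version
    raised = True
    print(type(err).__name__, err)

index = target.long().flatten()    # 1-D int64 tensor
print(index.shape, index.dtype)    # torch.Size([4]) torch.int64
```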
@raoashish10 did you find a solution?
Yes, I did! I think the index argument expects a tensor of all the classes; for example, if I have 10 classes, it expects tensor([0,1,2,3,4,5,6,7,8,9]). I declared a tensor like that and it worked for me. Let me know if it works for you too.
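For reference, an index containing every class id selects every row of the identity matrix. The result is therefore torch.eye(10) itself, with 10 rows instead of one per sample, no matter what the batch labels are. A minimal sketch (torch.eye is used directly here; torch.sparse.torch appears to be just the top-level torch module, so this should match the snippet above):

```python
import torch

eye = torch.eye(10)                # the 10x10 identity used for one-hot rows
all_classes = torch.arange(10)     # tensor([0, 1, 2, ..., 9]), one index per class

# The batch labels never enter this call, so the result cannot depend on them.
selected = eye.index_select(dim=0, index=all_classes)

print(selected.shape)              # torch.Size([10, 10])
print(torch.equal(selected, eye))  # True
```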
I want to update my answer, because my previous solution was a little misleading. You can use flatten() instead. In my case, I used it like this:
target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.long().flatten())

The problem with the previous solution is that it gets rid of the error but considerably decreases your accuracy.
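A quick sanity check of the flatten() version: with a 1-D integer index you get one one-hot row per sample, and it matches PyTorch's built-in torch.nn.functional.one_hot. The label values and their (4, 1) float shape below are hypothetical, since the thread doesn't show the actual target:

```python
import torch
import torch.nn.functional as F

# Hypothetical labels as a float column of shape (4, 1); the real dtype/shape
# of `target` in the thread isn't shown, so this is an assumption.
target = torch.tensor([[3.0], [0.0], [7.0], [3.0]])

# .long() gives the int64 dtype index_select accepts; .flatten() makes it 1-D.
index = target.long().flatten()
one_hot = torch.eye(10).index_select(dim=0, index=index)

print(one_hot.shape)          # torch.Size([4, 10]): one row per sample
print(one_hot.argmax(dim=1))  # tensor([3, 0, 7, 3]): the original labels

# Cross-check against PyTorch's built-in one-hot helper (returns int64).
assert torch.equal(one_hot, F.one_hot(index, num_classes=10).float())
```

F.one_hot returns an integer tensor, hence the .float() before comparing; if you don't need the identity-matrix trick, F.one_hot(index, num_classes=10) on its own may be the simpler choice.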