IndexError: index_select(): Index is supposed to be a vector

import torch
from torch.autograd import Variable
from tqdm import tqdm
for batch_id, (data, target) in enumerate(tqdm(train_loader)):
    print('Entered for loop')
    target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)

The line containing index_select raises this error, and I haven't been able to find a solution anywhere. When printed, the target variable looks like this:


How do I convert the target variable into a vector? Isn’t it already a vector?
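A minimal sketch of the failure. It assumes, hypothetically (the printed output above is not shown), that the DataLoader yields targets of shape (batch_size, 1) rather than (batch_size,). index_select only accepts a 1-D index tensor, so a column of labels is rejected even though it looks like a list of numbers when printed. Checking target.dim() or target.shape shows the difference.

```python
import torch

eye = torch.eye(10)  # row k is the one-hot vector for class k

# Hypothetical batch of 4 labels stored as a column, shape (4, 1)
target_2d = torch.tensor([[3], [1], [7], [0]])
print(target_2d.shape, target_2d.dim())  # torch.Size([4, 1]) 2

try:
    eye.index_select(dim=0, index=target_2d)
except (IndexError, RuntimeError) as err:
    # A 2-D index is rejected: index_select wants a 1-D index tensor
    print("index_select failed:", err)

# The same labels as a true vector, shape (4,), are accepted
target_1d = target_2d.flatten()
print(eye.index_select(dim=0, index=target_1d).shape)  # torch.Size([4, 10])
```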

@raoashish10 did you find a solution?

Yes, I did! I think the index argument expects a tensor of all the class indices. For example, if I have 10 classes, it expects tensor([0,1,2,3,4,5,6,7,8,9]). I declared a tensor like that and it worked for me. Let me know if it works for you too.

Thank you very much.

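I want to update my answer because my previous solution was a little misleading. The actual fix is flatten(): index_select needs a 1-D (vector) index tensor, so the target has to be flattened to shape (batch_size,) and cast to int64 with .long(). In my case, I used:
target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.long().flatten())
The problem with the previous solution is that it makes the error go away but ignores the actual labels: indexing with tensor([0,1,...,9]) always returns the 10x10 identity matrix, whatever the batch contains, so the model is trained against the wrong targets and accuracy drops considerably.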
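A sketch of the fix from the last post, assuming integer class labels in [0, 10) that arrive with an extra trailing dimension (the shape (4, 1) below is hypothetical). It uses plain torch.eye, since torch.sparse.torch is just the torch module re-exported. For comparison it also builds the same rows with torch.nn.functional.one_hot, and shows what the earlier torch.arange(10) workaround returns.

```python
import torch
import torch.nn.functional as F

num_classes = 10
# Hypothetical batch of labels with an extra trailing dimension, shape (4, 1)
target = torch.tensor([[3], [1], [7], [0]])

# Fix from the thread: cast to int64 and flatten to a 1-D index
one_hot = torch.eye(num_classes).index_select(dim=0, index=target.long().flatten())
print(one_hot.shape)  # torch.Size([4, 10])

# Built-in alternative: F.one_hot returns int64, so cast to float to match torch.eye
one_hot_builtin = F.one_hot(target.long().flatten(), num_classes=num_classes).float()
print(torch.equal(one_hot, one_hot_builtin))  # True

# The earlier workaround indexes with every class id and ignores the labels:
# the result is always the 10x10 identity matrix, whatever the batch contains
workaround = torch.eye(num_classes).index_select(dim=0, index=torch.arange(num_classes))
print(torch.equal(workaround, torch.eye(num_classes)))  # True
```

F.one_hot avoids materializing the full identity matrix, but either approach gives the same one-hot targets once the index is a proper 1-D integer tensor.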