Getting the row index from a DataLoader that loads from a sparse tensor


Hi there, I am trying to get the index of each sample from a DataLoader that loads from a `torch.sparse.FloatTensor`.

My code looks like this:

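```python
import torch
from torch.utils.data import DataLoader

i = torch.tensor([[0, 1, 2], [2, 0, 1]])  # example indices (2 x nnz)
v = torch.tensor([3.0, 4.0, 5.0])         # example values
sparseTensor = torch.sparse_coo_tensor(i, v, size=(3, 3))  # replaces torch.sparse.FloatTensor(i, v)
train_data = DataLoader(sparseTensor, batch_size=None)  # no batching: default collate_fn won't batch sparse tensors

for input in train_data:
    pass  # Need the index of the input here.
```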

The input is a sparse matrix, and I need the row number of each row as it is used for training.
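One possible approach, sketched under the assumption that feeding one sparse row at a time (no batching) is acceptable: wrap the sparse matrix in a small `Dataset` whose `__getitem__` returns the row index together with the row, so the loop receives both. The class name `IndexedRows` and the example data are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedRows(Dataset):
    """Hypothetical wrapper: yields (row_index, row) pairs from a sparse COO matrix."""

    def __init__(self, sparse_matrix):
        # Merge any duplicate entries so each row is well defined.
        self.matrix = sparse_matrix.coalesce()

    def __len__(self):
        return self.matrix.size(0)

    def __getitem__(self, idx):
        # Return the row number alongside the (still sparse) row itself.
        return idx, self.matrix[idx]


i = torch.tensor([[0, 1, 2], [2, 0, 1]])  # example indices (2 x nnz)
v = torch.tensor([3.0, 4.0, 5.0])         # example values
sparse_matrix = torch.sparse_coo_tensor(i, v, size=(3, 3))

# batch_size=None disables automatic batching, so each row stays sparse
# and is not passed through the default collate function.
loader = DataLoader(IndexedRows(sparse_matrix), batch_size=None)

seen = []
for row_idx, row in loader:
    # row_idx is the row number of `row` in the original sparse matrix.
    seen.append((row_idx, row.to_dense()))
```

If shuffling is not used, `enumerate(train_data)` in the original loop gives the same row numbers, because the default sequential sampler visits rows in order; with `shuffle=True`, only an approach like the wrapper above reports the true row index.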

Thanks for your time!