Getting the Data Points Classified Correctly

If I want to observe which data points are being classified correctly at every epoch, is there a simple way to do this? I'm thinking I would use torch.max(net(inputs), 1), but how would I obtain the indices of these data points? Given that in:

max_vals, max_indices = torch.max(net(inputs), 1)

max_indices actually holds the index of the predicted class label, not the index of the data point. Is there a better way to do this?
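A small sketch of what torch.max(..., 1) returns, using a made-up batch of logits and labels (the tensors here are hypothetical, not from the post). max_indices holds one predicted class per sample; comparing it with the labels gives a per-sample correctness mask, and torch.nonzero turns that mask into positions. Those positions are relative to the current batch, so they still have to be mapped back to Dataset indices.

```python
import torch

# Hypothetical batch: logits for 4 samples over 3 classes, plus true labels.
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.3, 0.2, 0.9],
                        [1.2, 0.4, 0.1]])
labels = torch.tensor([0, 2, 2, 1])

# max_indices[i] is the predicted class of sample i, not a sample index.
max_vals, max_indices = torch.max(outputs, 1)

# Boolean mask: True where the prediction matches the label.
correct_mask = max_indices == labels

# Positions of the correctly classified samples *within this batch*.
correct_positions = torch.nonzero(correct_mask).flatten()

print(max_indices.tolist())        # [0, 1, 2, 0]
print(correct_positions.tolist())  # [0, 2]
```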


What do you mean by the “index of the datapoint”? The index in your Dataset?

Hello, thank you for your response. Yes, exactly!

One approach is to have your Dataset return the index from its __getitem__ method. The DataLoader then passes it through with each batch, so you have access to it in the training loop.

# In the Dataset definition
def __getitem__(self, idx):
    # ... load foo and bar for sample idx as before ...
    return foo, bar, idx  # previously: return foo, bar

# In the training loop
for foos, bars, idxs in dataloader:
    # idxs now holds the Dataset index of each sample in the batch
    ...
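A self-contained sketch of the index-returning approach, assuming a hypothetical IndexedDataset wrapper around in-memory tensors and a toy nn.Linear model trained on made-up data. Because the default collate function stacks the returned integers into a tensor, idxs lines up row for row with the batch, so boolean-indexing it with the correctness mask yields the Dataset indices of the correctly classified samples, even with shuffle=True. Here correctness is recorded from each batch's forward pass, before that batch's optimizer step.

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class IndexedDataset(Dataset):
    """Hypothetical dataset that also returns each sample's index."""
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx], idx

torch.manual_seed(0)
inputs = torch.randn(100, 5)          # made-up data: 100 samples, 5 features
labels = (inputs[:, 0] > 0).long()    # made-up binary labels
dataset = IndexedDataset(inputs, labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

net = nn.Linear(5, 2)                 # toy model for illustration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

correct_per_epoch = []  # one set of Dataset indices per epoch
seen_per_epoch = []     # every index visited per epoch (for checking)
for epoch in range(3):
    correct_idxs, seen_idxs = set(), []
    for x, y, idxs in loader:
        outputs = net(x)
        _, preds = torch.max(outputs, 1)
        # idxs is a tensor of Dataset indices aligned with the batch rows
        correct_idxs.update(idxs[preds == y].tolist())
        seen_idxs.extend(idxs.tolist())

        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    correct_per_epoch.append(correct_idxs)
    seen_per_epoch.append(seen_idxs)
    print(f"epoch {epoch}: {len(correct_idxs)}/{len(dataset)} correct")
```

Comparing the sets in correct_per_epoch across epochs shows which samples become (or stop being) correctly classified as training progresses.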

Another approach is to use a DataLoader that does not shuffle the samples (shuffle=False). The samples then come out in increasing index order, from 0 to len(dataset) - 1, in consecutive groups of batch_size, so row j of batch i is Dataset sample i * batch_size + j.
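A sketch of the non-shuffled approach, again with made-up data, a plain TensorDataset and a toy nn.Linear model (all assumptions for illustration). The Dataset index of row j in batch batch_idx is recovered as batch_idx * batch_size + j; the last lines cross-check the result against a single forward pass over the whole dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
inputs = torch.randn(50, 5)           # made-up data: 50 samples, 5 features
labels = (inputs[:, 0] > 0).long()    # made-up binary labels
batch_size = 8
loader = DataLoader(TensorDataset(inputs, labels),
                    batch_size=batch_size, shuffle=False)

net = nn.Linear(5, 2)                 # toy model for illustration
correct_idxs = []
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(loader):
        _, preds = torch.max(net(x), 1)
        # Positions of correct predictions within this batch
        positions = torch.nonzero(preds == y).flatten()
        # With shuffle=False, batch row j is sample batch_idx * batch_size + j
        correct_idxs.extend((batch_idx * batch_size + positions).tolist())

# Cross-check against one pass over the whole dataset
with torch.no_grad():
    expected = torch.nonzero(net(inputs).argmax(1) == labels).flatten().tolist()
print(correct_idxs == expected)
```

This bookkeeping only holds while the samples arrive in order; once shuffle=True or a custom sampler is used, returning the index from __getitem__ is the safer choice.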