How to scan a dataset and trace back the model outputs to the data points?


(Milad Pourrahmani) #1

I would like to scan a dataset with a model and collect the outputs (say, 10 class probabilities per data point). After scanning, I want to know which output corresponds to which data point. The best way I can think of is to make my __getitem__ (in my DatasetMaker class) return (index, data, label) instead of just (data, label). This way, it doesn't matter whether my DataLoader shuffles or not (right?). Then I'll collect the model outputs in a dataframe.
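A minimal sketch of the (index, data, label) approach, assuming a map-style dataset and PyTorch's default collate function. IndexedDataset is a hypothetical stand-in for DatasetMaker, and the random toy data and nn.Linear model are assumptions for illustration only. Each batch's outputs are written into a preallocated tensor at the returned indices, so the shuffle order does not matter.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Hypothetical stand-in for DatasetMaker: returns (index, data, label)."""

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Returning the index lets us map each output back to its data point,
        # no matter what order the DataLoader requests items in.
        return index, self.data[index], self.labels[index]


torch.manual_seed(0)
data = torch.randn(25, 4)             # 25 toy data points with 4 features (assumption)
labels = torch.randint(0, 10, (25,))  # toy labels for 10 classes (assumption)
model = torch.nn.Linear(4, 10)        # toy model producing 10 class scores (assumption)

loader = DataLoader(IndexedDataset(data, labels), batch_size=8, shuffle=True)

model.eval()
probs = torch.empty(len(data), 10)    # row i will hold the output for data point i
with torch.no_grad():
    for idx, x, y in loader:
        # The default collate turns the returned ints into a LongTensor of
        # dataset indices for this (shuffled) batch; use it to scatter rows.
        probs[idx] = torch.softmax(model(x), dim=1)
```

If a dataframe is still wanted afterwards, pandas.DataFrame(probs.numpy()) gives one row per data point, with the DataFrame's row index equal to the dataset index.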

Is there a better way to scan and trace datasets?


(Jeong TaeYeong) #2

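Right. Shuffling only changes the order in which the DataLoader passes indices to __getitem__; a given index always refers to the same data point, so returning it alongside (data, label) keeps the mapping intact.
For example, if the shuffled order is [2, 0, 1], then __getitem__(2), __getitem__(0), and __getitem__(1) are called in that order.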
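To see this directly, the sketch below uses a hypothetical RecordingDataset that logs every index its __getitem__ receives, with a seeded generator driving the shuffle. It assumes num_workers=0 (the default), so __getitem__ runs in the main process and the log is visible after the loop.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RecordingDataset(Dataset):
    """Toy dataset that logs every index __getitem__ receives (illustration only)."""

    def __init__(self, n):
        self.n = n
        self.calls = []

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        self.calls.append(index)
        # The item is a pure function of the index, so shuffling cannot
        # change which data point a given index refers to.
        return index, index * 10


ds = RecordingDataset(6)
loader = DataLoader(ds, batch_size=2, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

seen = []
for idx, value in loader:
    seen.extend(idx.tolist())
    # Each value still matches its own index, whatever the batch order.
    assert (value == idx * 10).all()

# ds.calls is a permutation of 0..5, and it equals the order of indices
# that came out of the loader batch by batch.
```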

If you turn off shuffling (shuffle=False), the outputs are guaranteed to come out in the same order as the items in the dataset.
So you can do something like this:

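import torch

model.eval()
cumidx, results = 0, {}
with torch.no_grad():  # no gradients needed when only collecting outputs
    for inputs, labels in loader:  # loader must be created with shuffle=False
        outputs = model(inputs)
        for j in range(outputs.size(0)):
            results[cumidx + j] = outputs[j]  # key = position in the dataset
        cumidx += outputs.size(0)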

This approach doesn't require the Dataset class to return the index of the corresponding data point.
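With shuffle=False, the per-item loop can also be replaced by concatenating the batch outputs, so that row i of the result belongs to dataset[i]. This is a sketch only: the TensorDataset, random toy data, and nn.Linear model are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
data = torch.randn(25, 4)              # toy inputs (assumption)
labels = torch.randint(0, 10, (25,))   # toy labels (assumption)
dataset = TensorDataset(data, labels)  # yields (data, label) with no index
model = torch.nn.Linear(4, 10)         # toy model (assumption)

# shuffle=False: batches arrive in dataset order 0, 1, 2, ...
loader = DataLoader(dataset, batch_size=8, shuffle=False)

model.eval()
batches = []
with torch.no_grad():
    for x, _ in loader:
        batches.append(model(x))

outputs = torch.cat(batches)           # row i is the output for dataset[i]
```

Keeping one tensor instead of a dict keyed by position makes it easy to convert the results in one step, for example with outputs.numpy().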