How to retrieve the sample indices of a mini-batch

You could adapt this code, which wraps the dataset so that every sample is returned together with its index:

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """CIFAR10 wrapper whose samples also carry their dataset index.

    Indexing an instance returns ``(data, target, index)``, where ``data``
    and ``target`` come from the wrapped ``datasets.CIFAR10`` object and
    ``index`` is the index that was requested.

    Args:
        root: Forwarded to ``datasets.CIFAR10`` as ``root``.
        train: Forwarded to ``datasets.CIFAR10`` as ``train``.
        download: Forwarded to ``datasets.CIFAR10`` as ``download``.
        transform: Forwarded to ``datasets.CIFAR10`` as ``transform``.
            When ``None`` (the default), ``transforms.ToTensor()`` is used.
    """

    def __init__(self, root='YOUR_PATH', train=True, download=False,
                 transform=None):
        super().__init__()
        # Build the default transform lazily so that a caller-supplied one
        # replaces it instead of being stacked on top of it.
        if transform is None:
            transform = transforms.ToTensor()
        self.cifar10 = datasets.CIFAR10(root=root,
                                        download=download,
                                        train=train,
                                        transform=transform)

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the sample at ``index``."""
        data, target = self.cifar10[index]

        # Your transformations here (or set it in CIFAR10)

        # Hand the index back with the sample so the consumer can tell
        # which dataset entry it received.
        return data, target, index

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar10)

# Run the example only when this file is executed as a script. With
# num_workers > 0, platforms that start worker processes via "spawn"
# (e.g. Windows, macOS) re-import the main module in every worker; without
# this guard the code below would be executed again inside each worker.
if __name__ == '__main__':
    dataset = MyDataset()
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=True,
                        num_workers=1)

    # `idx` is the batched third element returned by MyDataset.__getitem__,
    # i.e. the dataset indices of the samples in the current mini-batch.
    for batch_idx, (data, target, idx) in enumerate(loader):
        print('Batch idx {}, dataset index {}'.format(
            batch_idx, idx))
17 Likes