How to correctly iterate over dataset indices

I have this code for my weighted sampler:

loader_iter = iter(train_dataset)
for _ in range(50):
    idx, (frames, label, file_name), _ = next(loader_iter)
    class_weight = class_weights[label]
    sample_weights[idx] = class_weight
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

but I am getting this error:
TypeError: 'int' object is not iterable

Then I tried wrapping the next() call in range, like this:

idx, (frames, label, file_name), _ = range(next(loader_iter))

but that gives a different error: TypeError: 'tuple' object cannot be interpreted as an integer

Then I tried wrapping it with enumerate:
idx, (frames, label, file_name) = enumerate(next(loader_iter))

but I am getting ValueError: too many values to unpack (expected 2)

Any ideas for me? Thank you in advance

The general logic looks correct and also works for me:

import torch
from torch.utils.data import TensorDataset

# Each sample is a (data, target) pair
dataset = TensorDataset(torch.arange(10), torch.arange(1, 11))

dataset_iter = iter(dataset)
for _ in range(10):
    data, target = next(dataset_iter)

It’s unclear how your train_dataset is defined, so you might need to debug its __getitem__ method and check which line of code raises the error.
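
As a rough sketch of that debugging step, the snippet below inspects what a single sample looks like and then takes the index from the loop instead of from the sample. ToyDataset is a hypothetical stand-in written in plain Python. It assumes that __getitem__ returns a flat (frames, label, file_name) tuple, which may not match your real train_dataset. If your dataset does return that shape, the nested pattern idx, (frames, label, file_name), _ would try to unpack the integer label, which might explain the first TypeError.

```python
class ToyDataset:
    """Hypothetical stand-in for train_dataset.

    Assumption: each sample is a flat (frames, label, file_name) tuple.
    """

    def __init__(self):
        self.samples = [
            ([0.1, 0.2], 0, "a.mp4"),
            ([0.3, 0.4], 1, "b.mp4"),
            ([0.5, 0.6], 1, "c.mp4"),
            ([0.7, 0.8], 1, "d.mp4"),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


dataset = ToyDataset()

# Step 1: look at one sample before writing the unpacking pattern.
sample = dataset[0]
print(type(sample), len(sample))
for element in sample:
    print(type(element))

# Step 2: the index comes from the loop, not from the sample itself.
class_weights = [1.0 / 1, 1.0 / 3]  # illustrative inverse class frequencies
sample_weights = [0.0] * len(dataset)
for idx in range(len(dataset)):
    frames, label, file_name = dataset[idx]
    sample_weights[idx] = class_weights[label]
```

Once the unpacking pattern matches what __getitem__ actually returns, the resulting sample_weights can be passed to WeightedRandomSampler as in the original code.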
