Dataset: Multiple Samples per __getitem__ Call

Assuming the tensor returned by __getitem__ already contains the desired samples, you could flatten the extra sample dimension into the batch dimension inside the DataLoader loop via:


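for batch in dls:
    sample = batch['sample']     # [batch_size, samples_per_item, 5]
    sample = sample.view(-1, 5)  # merge the batch and per-item sample dimensions
    print(sample.shape) # should print [4, 5] now (e.g. batch_size=2 with 2 samples per item)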
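To make the shapes concrete, here is a minimal self-contained sketch. The dataset class `MultiSampleDataset` and its sizes (2 samples of 5 features per `__getitem__` call, `batch_size=2`) are hypothetical and only chosen to reproduce the `[4, 5]` shape mentioned above; the default collate function stacks the dict entries into a `[2, 2, 5]` tensor, which `view(-1, 5)` then flattens.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MultiSampleDataset(Dataset):
    """Hypothetical dataset: each __getitem__ call returns several samples at once."""

    def __init__(self, num_items=8, samples_per_item=2, num_features=5):
        self.data = torch.randn(num_items, samples_per_item, num_features)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        # One item holds a [samples_per_item, num_features] tensor
        return {'sample': self.data[idx]}


dls = DataLoader(MultiSampleDataset(), batch_size=2)

shapes = []
for batch in dls:
    sample = batch['sample']     # default collate stacks items: [2, 2, 5]
    sample = sample.view(-1, 5)  # flatten to [2 * 2, 5] = [4, 5]
    shapes.append(tuple(sample.shape))
    print(sample.shape)
```

With 8 items and `batch_size=2`, the loop runs four times and each flattened batch has shape `[4, 5]`. Since the default collate creates a fresh contiguous tensor via stacking, `view` is safe here; `reshape` would be the more forgiving choice if the tensor could be non-contiguous.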

Using a custom collate_fn that concatenates the samples along the first dimension would probably also work, but I think the view operation might be easier.
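For comparison, a custom collate_fn could be sketched as follows. The names `MultiSampleDataset` and `flatten_collate` are hypothetical, and this assumes every item is a dict with a `'sample'` tensor of shape `[samples_per_item, 5]`; `torch.cat` along `dim=0` then yields the flattened `[4, 5]` batch directly, so the training loop needs no extra `view`.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MultiSampleDataset(Dataset):
    """Hypothetical dataset returning 2 samples of 5 features per __getitem__ call."""

    def __init__(self, num_items=8, samples_per_item=2, num_features=5):
        self.data = torch.randn(num_items, samples_per_item, num_features)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return {'sample': self.data[idx]}


def flatten_collate(items):
    # items is a list of dicts, one per __getitem__ call in this batch.
    # Concatenating along dim 0 turns every per-item sample into its own row.
    return {'sample': torch.cat([item['sample'] for item in items], dim=0)}


ds = MultiSampleDataset()
loader = DataLoader(ds, batch_size=2, collate_fn=flatten_collate)

first = next(iter(loader))['sample']
print(first.shape)
```

The trade-off is mostly about where the logic lives: the collate_fn keeps the loop clean and works for every consumer of the loader, while the `view` approach needs no extra function but must be repeated wherever the batches are used.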
