DataLoader batch parameter

Each sample in a batch of data is an array. For a given batch, I only want to keep a single index along the first dimension of that array.

Essentially, I want to go from N, K, C, H, W to N, C, H, W by randomly sampling an index in [0, K) for each batch.
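For reference, once a k has been drawn, the shape change itself is a single indexing operation. A minimal sketch with random dummy data and arbitrary sizes, showing both one k per sample (advanced indexing) and one k shared by the whole batch:

```python
import torch

N, K, C, H, W = 10, 5, 3, 4, 4
x = torch.randn(N, K, C, H, W)  # dummy data

# One random index per sample; torch.randint's upper bound is exclusive, so k is in [0, K)
k = torch.randint(0, K, (N,))
# Pair row i with k[i]; the K dimension disappears
out = x[torch.arange(N), k]

# One random index shared by the whole batch
k_shared = int(torch.randint(0, K, (1,)))
out_shared = x[:, k_shared]

print(out.shape, out_shared.shape)
```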

How do I accomplish this using DataLoader? I think it’s either a collate_fn or a worker_init_fn.

Ideally, I need this to run before __getitem__.

You could use default_collate as a starting point and add your random slicing to a custom collate function.
Here is a small example for your use case:

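# torch._six was removed in recent PyTorch releases; collections.abc provides Sequence directly
import collections.abc

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.target = torch.zeros(data.size(0))  # dummy scalar target per sample

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        return x, y

    def __len__(self):
        return len(self.data)

def my_collate(batch):
    if isinstance(batch[0], torch.Tensor):
        if batch[0].dim() == 4:
            # Each sample is [K, C, H, W]: keep one random k in [0, K) per sample
            batch = [b[torch.randint(0, b.size(0), (1,)).item()] for b in batch]
        return torch.stack(batch, 0)
    if isinstance(batch[0], collections.abc.Sequence):
        # (data, target) pairs: collate each field separately
        transposed = zip(*batch)
        return [my_collate(samples) for samples in transposed]
    return default_collate(batch)

N, K, C, H, W = 10, 5, 3, 4, 4
data = torch.randn(N, K, C, H, W)
dataset = MyDataset(data)
loader = DataLoader(dataset, batch_size=2, collate_fn=my_collate)

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([2, 3, 4, 4]) torch.Size([2])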

Note that I just implemented the collate function for data with 4 dimensions and a scalar target.
If your target has the same dimensions as your data, this won’t work.
Let me know if that’s the case.
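If the target does carry the same K dimension, one possible variant (a sketch, not tested against your real data) is to collate the (data, target) pairs directly and draw a single k per sample that is applied to both. PairDataset and pair_collate are hypothetical names; the target here is just a clone of the data so the two slices can be checked against each other.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PairDataset(Dataset):
    """Hypothetical dataset whose data and target both have shape [K, C, H, W]."""
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)

def pair_collate(batch):
    xs, ys = [], []
    for x, y in batch:
        k = int(torch.randint(0, x.size(0), (1,)))  # one k per sample, shared by x and y
        xs.append(x[k])
        ys.append(y[k])
    return torch.stack(xs, 0), torch.stack(ys, 0)

N, K, C, H, W = 10, 5, 3, 4, 4
data = torch.randn(N, K, C, H, W)
target = data.clone()  # same shape as the data, so matching slices must be equal
loader = DataLoader(PairDataset(data, target), batch_size=4, collate_fn=pair_collate)

for x, y in loader:
    print(x.shape, y.shape)
```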


Ah great, this just about works. Thanks!


@ptrblck, it turns out that doesn’t quite work.

Right now, my __getitem__ does a lot of unnecessary I/O and preprocessing: it produces the full N, K, C, H, W tensor, which the collate function you wrote then reduces to N, C, H, W (shown on the left in my diagram).

What I want is shown in the diagram on the right: __getitem__ returns a 1, C, H, W tensor, with no unnecessary I/O or preprocessing. The 1 comes from sampling an index k in [0, K), and k has to be the same across the whole batch. (I also need to know which index was chosen.)

Does that make sense? Any ideas on how to proceed?

Maybe this question is too specific to my use case…

Thanks for your help!

Let me clarify your use case a bit.
You would like to use the same randomly sampled k for the complete batch of samples.
The next batch would therefore get a newly sampled k and use it, right?

Since your data loading is expensive, you want to sample k before __getitem__ is called, so that you can speed up the loading somehow?

Yes, exactly! At the end of the process, I get one batch and the sampled value k.

I achieved this by delaying I/O and preprocessing until the collate function, but it’s not very flexible.

At the moment, I do:

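import numpy as np

def collate_data(batch):
    k = np.random.randint(0, 10)  # one k in [0, 10) for the whole batch (K = 10 here)
    batch = [sample[k] for sample in batch]  # keep only the k-th entry of each sample
    data = load_batch(batch)  # my own I/O function
    augmented_data = augment(data)  # my own augmentation
    return augmented_data, k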

This seems a bit clunky, but it works. It would be nice if I could switch off augmentation with a flag. Maybe I should use a lambda function?

Something like:

DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag))

I’d be interested to know if there are other ways of achieving this.
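One other way that moves the choice of k in front of __getitem__ is a custom batch_sampler. DataLoader calls dataset[...] on each element of every list the batch sampler yields, so the sampler can hand out (index, k) tuples with one k drawn per batch. Since the batch sampler runs in the main process, k is drawn there even with num_workers > 0, so every sample in a batch sees the same value. This is a sketch under assumptions: KBatchSampler, LazyDataset and collate_with_k are hypothetical names, and LazyDataset returns tensors filled with k instead of doing real I/O.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset

K = 5

class KBatchSampler:
    """Hypothetical batch sampler: draws one k per batch and pairs it with every index."""
    def __init__(self, num_samples, batch_size, num_k, shuffle=True):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_k = num_k
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(self.num_samples))
        if self.shuffle:
            random.shuffle(order)
        for start in range(0, self.num_samples, self.batch_size):
            k = random.randrange(self.num_k)  # same k for the whole batch
            yield [(i, k) for i in order[start:start + self.batch_size]]

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

class LazyDataset(Dataset):
    """Hypothetical stand-in for the expensive dataset."""
    def __init__(self, num_samples, C, H, W):
        self.num_samples = num_samples
        self.shape = (C, H, W)

    def __getitem__(self, index):
        i, k = index  # the batch sampler passes (sample index, k) tuples
        # Real code would read and preprocess only slice k of sample i here;
        # this dummy tensor is just filled with k so the result is easy to check.
        return torch.full(self.shape, float(k)), k

    def __len__(self):
        return self.num_samples

def collate_with_k(batch):
    xs, ks = zip(*batch)
    return torch.stack(xs, 0), ks[0]  # all ks in one batch are identical

dataset = LazyDataset(10, 3, 4, 4)
loader = DataLoader(dataset, batch_sampler=KBatchSampler(len(dataset), 4, K),
                    collate_fn=collate_with_k)

for x, k in loader:
    print(x.shape, k)
```

The trade-off is that the sampler now owns shuffling and batching, so batch_size, shuffle and drop_last must not be passed to DataLoader alongside batch_sampler.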


Lambda functions are how I solved this. It works, and it’s somewhat clunky, but ah well.

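import numpy as np

def collate_data(batch, augmentFlag):
    k = np.random.randint(0, 10)  # one k for the whole batch
    batch = [sample[k] for sample in batch]
    data = load_batch(batch)
    if augmentFlag:
        data = augment(data)  # skipped when augmentFlag is False (e.g. validation)
    return data, k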

The DataLoaders with the lambda functions would be:

val_loader = DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag=False))

train_loader = DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag=True))
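An alternative to the lambdas, sketched under assumptions: functools.partial binds the flag the same way, but unlike an inline lambda it can be pickled as long as collate_data is defined at module level. As far as I know, this matters once num_workers > 0 and worker processes are started with the spawn method, which is the default on Windows and macOS. load_batch and augment below are hypothetical stand-ins for your own functions, the samples are random dummy tensors, and Python's random replaces np.random only to keep the sketch dependency-free.

```python
import functools
import random

import torch
from torch.utils.data import DataLoader

K = 10

def load_batch(batch):
    # Hypothetical stand-in for the real I/O: each "sample" is already a tensor here
    return torch.stack(batch, 0)

def augment(data):
    # Hypothetical stand-in augmentation: random horizontal flip of the whole batch
    return torch.flip(data, dims=[-1]) if random.random() < 0.5 else data

def collate_data(batch, augmentFlag):
    k = random.randrange(K)                  # one k in [0, K) for the whole batch
    batch = [sample[k] for sample in batch]  # keep only the k-th entry of each sample
    data = load_batch(batch)
    if augmentFlag:
        data = augment(data)
    return data, k

samples = [torch.randn(K, 3, 4, 4) for _ in range(8)]  # dummy K-stacked samples
train_loader = DataLoader(samples, batch_size=4,
                          collate_fn=functools.partial(collate_data, augmentFlag=True))
val_loader = DataLoader(samples, batch_size=4,
                        collate_fn=functools.partial(collate_data, augmentFlag=False))

for data, k in val_loader:
    print(data.shape, k)
```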