Conditional batch generation

Hi, everyone! I’m a bit stuck on a fairly simple problem :weary: Say we have a dataset of vectors whose last element is either 0 or 1. I need each batch to contain vectors of only one type. So the dataloader has to:

  • pick 0 or 1 randomly;
  • pick a subset of the dataset where vectors end with the number generated in the previous step;
  • pick a random batch from the subset above.

How would you implement it in PyTorch?
Thanks in advance!

You can implement a custom Dataset with the logic that you described in __getitem__ (or an IterableDataset with the logic in __iter__). There is an example of writing a custom Dataset in the documentation.
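As a rough illustration of the IterableDataset route, here is a minimal sketch that follows the three steps from the question inside __iter__. The class name `SingleLabelBatches`, the `num_batches` parameter, and the toy data are made up for this example, and it assumes the whole dataset fits in one 2-D tensor and that both 0 and 1 occur as last elements. It yields whole batches, so the DataLoader is created with batch_size=None to switch off automatic batching.

```python
import random

import torch
from torch.utils.data import DataLoader, IterableDataset


class SingleLabelBatches(IterableDataset):
    """Yields whole batches in which every vector ends with the same value."""

    def __init__(self, data, batch_size, num_batches):
        self.data = data
        self.batch_size = batch_size
        self.num_batches = num_batches  # batches per pass (hypothetical knob)
        last = data[:, -1]
        # Precompute the row indices of each subset once.
        self.pools = {v: torch.nonzero(last == v).flatten() for v in (0, 1)}

    def __iter__(self):
        for _ in range(self.num_batches):
            label = random.randint(0, 1)  # step 1: pick 0 or 1
            pool = self.pools[label]      # step 2: rows ending with that value
            # step 3: random batch from that subset (without replacement)
            chosen = pool[torch.randperm(len(pool))[: self.batch_size]]
            yield self.data[chosen]


# Toy data: 100 vectors of length 5 whose last element alternates 0, 1, 0, 1, ...
features = torch.randn(100, 4)
labels = (torch.arange(100) % 2).float().unsqueeze(1)
data = torch.cat([features, labels], dim=1)

# batch_size=None disables automatic batching: each yielded tensor is one batch.
loader = DataLoader(
    SingleLabelBatches(data, batch_size=8, num_batches=5), batch_size=None
)

for batch in loader:
    print(batch.shape, batch[0, -1].item())
```

Note that with num_workers > 0 every worker would run its own copy of __iter__, so the number of batches per pass would be multiplied unless the work is split per worker.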

Thanks! Does __getitem__ return the whole batch, or a single item?

__getitem__ should return just a single item. The Dataset is not expected to know what batch size a downstream consumer such as a DataLoader specifies.
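If you want to keep __getitem__ returning single items, one alternative (not mentioned above, so treat it as a suggestion rather than the recommended way) is to move the grouping logic into a custom batch sampler and pass it through the DataLoader's batch_sampler argument. The sketch below assumes both 0 and 1 occur in the data; the name `SingleLabelBatchSampler` and the `num_batches` parameter are invented for the example.

```python
import random

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SingleLabelBatchSampler(Sampler):
    """Yields lists of indices; all indices in a list share the same last element."""

    def __init__(self, last_elements, batch_size, num_batches):
        self.batch_size = batch_size
        self.num_batches = num_batches  # batches per epoch (hypothetical knob)
        # Group dataset indices by the value of the last element.
        self.pools = {0: [], 1: []}
        for idx, value in enumerate(last_elements.tolist()):
            self.pools[int(value)].append(idx)

    def __iter__(self):
        for _ in range(self.num_batches):
            pool = self.pools[random.randint(0, 1)]  # pick 0 or 1, take its subset
            # Sample a batch of indices from that subset without replacement.
            yield random.sample(pool, min(self.batch_size, len(pool)))

    def __len__(self):
        return self.num_batches


# Toy data: last element alternates 0, 1, 0, 1, ...
data = torch.cat(
    [torch.randn(100, 4), (torch.arange(100) % 2).float().unsqueeze(1)], dim=1
)
dataset = TensorDataset(data)  # __getitem__ returns a single item, as a 1-tuple

sampler = SingleLabelBatchSampler(data[:, -1], batch_size=8, num_batches=5)
batch_loader = DataLoader(dataset, batch_sampler=sampler)

for (batch,) in batch_loader:
    print(batch.shape, batch[:, -1].unique().tolist())
```

Here the Dataset stays unaware of batching, and the default collate function stacks the sampled items as usual.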

Yes, that is what I was assuming. However, I couldn’t find any source on how to make a DataLoader return batches conditionally.

In theory, you can make __getitem__ conditionally return a batch, and then set batch_size = 1 for DataLoader, but this is outside the general use case of Dataset/DataLoader.
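For completeness, a small sketch of that workaround, under the same assumptions as the question (a 2-D tensor whose last column is 0 or 1, with both values present). `BatchPerItem` and `num_batches` are hypothetical names. With batch_size=1 the default collate function adds a leading dimension of size 1 that you would have to squeeze away; passing batch_size=None instead disables automatic batching and returns the batch as is.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class BatchPerItem(Dataset):
    """Map-style dataset whose 'items' are whole single-label batches."""

    def __init__(self, data, batch_size, num_batches):
        self.data = data
        self.batch_size = batch_size
        self.num_batches = num_batches
        last = data[:, -1]
        self.pools = [torch.nonzero(last == v).flatten() for v in (0, 1)]

    def __len__(self):
        # Number of batches per epoch, not number of vectors.
        return self.num_batches

    def __getitem__(self, index):
        # The index is ignored; every call draws a fresh random batch.
        pool = self.pools[torch.randint(0, 2, (1,)).item()]
        chosen = pool[torch.randperm(len(pool))[: self.batch_size]]
        return self.data[chosen]


data = torch.cat(
    [torch.randn(100, 4), (torch.arange(100) % 2).float().unsqueeze(1)], dim=1
)
batched_items = BatchPerItem(data, batch_size=8, num_batches=5)

wrapped = DataLoader(batched_items, batch_size=1)   # shape (1, 8, 5) per step
plain = DataLoader(batched_items, batch_size=None)  # shape (8, 5) per step

print(next(iter(wrapped)).shape)
print(next(iter(plain)).shape)
```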