I’m trying to implement multi-domain learning using PyTorch. The problem is that I need all the samples within a batch to be from the same domain. I will have a CSV file containing the domain of each sample.
Is there a way to select only the samples of the same domain to create a batch?
There are more choices you will need to make, e.g. what happens if each domain has a different number of samples, or the number of samples in a domain is not divisible by the batch size. I would subclass torch.utils.data.Sampler to write custom logic for creating batches that satisfy this constraint.
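A minimal sketch of such a batch sampler, assuming `domains` is a list mapping each sample index to its domain label (e.g. a column read from your CSV). The class name, the choice to drop incomplete trailing batches per domain, and the shuffling strategy are all illustrative assumptions, not the only way to do it:

```python
import random
from collections import defaultdict


class DomainBatchSampler:
    """Yields lists of indices where all samples share one domain.

    `domains` gives the domain label of sample i at position i.
    Incomplete trailing batches within a domain are dropped here;
    keep them instead if partial batches are acceptable.
    """

    def __init__(self, domains, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Group sample indices by their domain label.
        self.by_domain = defaultdict(list)
        for idx, dom in enumerate(domains):
            self.by_domain[dom].append(idx)

    def __iter__(self):
        batches = []
        for indices in self.by_domain.values():
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            # Slice each domain's indices into full batches only.
            for i in range(0, len(indices) - self.batch_size + 1, self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            # Mix domains between batches so training doesn't see
            # all of one domain before the next.
            random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(len(v) // self.batch_size for v in self.by_domain.values())
```

You would then pass it via the `batch_sampler` argument, e.g. `DataLoader(dataset, batch_sampler=DomainBatchSampler(domains, batch_size=32))` — `batch_sampler` accepts any iterable of index lists, so the class does not need to import anything from torch.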
As an alternative, you could construct a dataset per domain, create a loader per domain, and juggle between these dataloaders (e.g. use round-robin).
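A sketch of the round-robin alternative. The helper below alternates batches between any iterables (each would be a `DataLoader` over one domain's dataset) and stops once the first one is exhausted; stopping behavior is an assumption you may want to change, e.g. continue drawing from the remaining loaders instead:

```python
def round_robin(*loaders):
    """Yield one batch from each loader in turn.

    Stops as soon as any loader runs out, so trailing batches from
    larger domains are skipped; adapt if you want to exhaust them all.
    """
    iterators = [iter(loader) for loader in loaders]
    while True:
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                return
```

In training you would write `for batch in round_robin(loader_a, loader_b): ...`, where each batch comes entirely from one domain by construction.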