Torch DataLoader for a custom requirement

Hello! I am new to PyTorch coding; could someone advise me on the following requirement?

I have an ordered dataset (shuffle=False) that is categorised into “bins”. Here is a small example to clarify: say the dataset has 60 samples split into bins of sizes 10, 20, and 30. I want to train my model bin by bin, in order (first the 10, then the 20, then the 30), with the DataLoader fetching data in batches of size 8. After the first 8 datapoints, I do not want a batch made of the 2 remaining from bin-1 plus 6 from the next bin; I want a batch of just those 2, and only in the next iteration get 8 from bin-2. In short, I want to finish training on one bin before moving on to the next. Likewise, if the batch size happens to be greater than a bin's size, each batch should still come from a single bin.

Can I please get some advice on how to do this? I could think of two ways: implementing a custom DataLoader (I need advice on this too), or simply creating a separate DataLoader for each bin and, with the bins in the outermost loop, grabbing the corresponding DataLoader for training. Would the latter method have any serious downsides?
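
Roughly, I imagine the second option (one DataLoader per bin) would look something like this untested sketch, where MySet and the bin boundaries are just placeholders for my actual dataset and bins:

# Untested sketch of option 2: one DataLoader per contiguous bin.
from torch.utils.data import DataLoader, Subset

def make_bin_loaders(dataset, cumulative_bins, batch_size):
    loaders = []
    start = 0
    for end in cumulative_bins:
        bin_subset = Subset(dataset, list(range(start, end)))  # one contiguous bin
        loaders.append(DataLoader(bin_subset, batch_size=batch_size, shuffle=False))
        start = end
    return loaders

# outer loop over bins, inner loop over batches within the bin:
# for loader in make_bin_loaders(data, [10, 30, 60], 8):
#     for batch in loader:
#         ...  # training step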

I found a way! I ended up writing a custom BatchSampler class to feed into the DataLoader.
Assume that bins = [10, 30, 60] (the cumulative counts of the bin sizes) and batch_size=8, as in my question.

import math
from torch.utils.data import Sampler

class CustomBatchSampler(Sampler):
    def __init__(self, sampler, batch_size, bins):
        self.sampler = sampler
        self.batch_size = batch_size
        self.bins = bins  # cumulative bin boundaries, e.g. [10, 30, 60]
        # each bin contributes ceil(bin_size / batch_size) batches
        self.num_batches = sum(math.ceil((j - i) / batch_size)
                               for i, j in zip([0] + self.bins, self.bins))

    def __iter__(self):
        batch = [0] * self.batch_size
        idx_in_batch = 0
        bin_count = 0
        for total, idx in enumerate(self.sampler):
            batch[idx_in_batch] = idx
            idx_in_batch += 1
            # end of the current bin: flush a (possibly short) batch
            if total + 1 == self.bins[bin_count]:
                yield batch[:idx_in_batch]
                idx_in_batch = 0
                batch = [0] * self.batch_size
                bin_count += 1
            # batch filled up within the current bin: flush a full batch
            elif idx_in_batch == self.batch_size:
                yield batch
                idx_in_batch = 0
                batch = [0] * self.batch_size

    def __len__(self):
        return self.num_batches

I used it like this:

from torch.utils.data import DataLoader, SequentialSampler

data = MySet(...)  # my Dataset class
batch_sampler = CustomBatchSampler(SequentialSampler(data), batch_size, bins)
loader = DataLoader(data, batch_sampler=batch_sampler)
for each in loader:
    print(len(each["label"]))

Output (which was expected):

8
2
8
8
4
8
8
8
6

It works in other cases too, but I would like to know whether this is safe, especially if I am doing multiprocessing or distributed computing (I am using the accelerate library).
Are there expected behaviours of a BatchSampler that I am missing? (Perhaps the num_batches? Or some dependency on iter(BatchSampler) yielding lists of the same length?)
Could someone please clarify this? I would be very grateful.
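
One quick check I could think of for the __len__ part is to compare it with the number of batches actually produced (just a sanity check, not a guarantee about the multiprocessing behaviour):

produced = sum(1 for _ in batch_sampler)
assert produced == len(batch_sampler)  # 9 batches for bins=[10, 30, 60] and batch_size=8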

I believe your approach is valid and would have recommended the same.
I would not expect any issues when using your CustomBatchSampler with num_workers>0. The only thing I would check is whether your model requires a minimal batch size (e.g. some models fail with a batch size of 1 when batchnorm statistics cannot be calculated).
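
For example (an illustrative snippet, not from your code), a batchnorm layer in training mode raises an error when it receives a single sample, which could happen here whenever a bin size modulo the batch size equals 1:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.train()
bn(torch.randn(8, 4))      # fine
try:
    bn(torch.randn(1, 4))  # single sample: "Expected more than 1 value per channel when training"
except ValueError as e:
    print(e)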

Thanks a lot! I’ll keep an eye on the batch size :+1:.