BalancedBatchSchedulerSampler is the custom sampler, and I have also set replace_sampler_ddp to False. With this custom sampler, the checkpoint folder is not created for the model. When I don’t pass the sampler argument and use the default RandomSampler instead, the checkpoint is created, with no other change in the code.
Is it possible that the sampler is somehow affecting model checkpointing?
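For context, here is roughly how the sampler is wired in. This is only a sketch: the three TensorDataset placeholders, their sizes, and the batch size are made up, and I'm assuming a PyTorch Lightning 1.x Trainer where replace_sampler_ddp is still a valid flag (BalancedBatchSchedulerSampler is the class shown further down).

import torch
import pytorch_lightning as pl
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# placeholder datasets standing in for the three tasks
task_datasets = [TensorDataset(torch.randn(n, 8)) for n in (1000, 400, 250)]
concat_dataset = ConcatDataset(task_datasets)

batch_size = 32
sampler = BalancedBatchSchedulerSampler(concat_dataset, batch_size=batch_size)
train_loader = DataLoader(concat_dataset, batch_size=batch_size, sampler=sampler)

# Lightning keeps the custom sampler because its automatic replacement is disabled
trainer = pl.Trainer(replace_sampler_ddp=False, max_epochs=1)
# trainer.fit(model, train_loader)  # model not shown here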
Thanks so much for pointing to the __len__ method. I have a multi-task model, and the three tasks have datasets of different sizes. The __len__ method is implemented so that each dataset contributes an equal proportion of samples per epoch (the smaller datasets are resampled up to the size of the largest). I went through the PyTorch documentation on Sampler and DataLoader, but I'm not sure what I should change. Did you mean that the number of samples is too large?
Thank you
import math
from random import shuffle

import torch
from torch.utils.data.sampler import RandomSampler


class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over tasks and provide a balanced batch per task in each mini-batch.
    """

    def __init__(self, dataset, batch_size):
        super().__init__(dataset)
        self.dataset = dataset  # expected to be a torch.utils.data.ConcatDataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max(
            len(cur_dataset) for cur_dataset in dataset.datasets
        )

    def __len__(self):
        return (
            self.batch_size
            * math.ceil(self.largest_dataset_size / self.batch_size)
            * len(self.dataset.datasets)
        )

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            sampler_iterators.append(iter(sampler))

        # offset of each sub-dataset inside the ConcatDataset index space
        push_index_val = [0] + list(self.dataset.cumulative_sizes[:-1])
        step = self.batch_size
        samples_to_grab = math.ceil(self.batch_size / self.number_of_datasets)
        # we want to see every sample of the largest dataset, which forces us
        # to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # list of indices into the combined dataset
        for _ in range(0, epoch_samples, step):
            cur_batch_samples = []
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = next(cur_batch_sampler)
                    except StopIteration:
                        # reached the end of this iterator - restart it and keep
                        # drawing samples until "epoch_samples" is reached
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = next(cur_batch_sampler)
                    cur_samples.append(cur_sample_org + push_index_val[i])
                cur_batch_samples.extend(cur_samples)
            shuffle(cur_batch_samples)
            final_samples_list.extend(cur_batch_samples)

        return iter(final_samples_list)
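On the __len__ point, here is a quick sanity check I can run (the dataset sizes are made up for illustration), comparing what len() reports against how many indices __iter__ actually yields, since the DataLoader derives its epoch length from the former:

import torch
from torch.utils.data import ConcatDataset, TensorDataset

# made-up sizes for the three tasks, just to compare the two numbers
sizes = (1000, 400, 250)
concat_dataset = ConcatDataset([TensorDataset(torch.randn(n, 8)) for n in sizes])
sampler = BalancedBatchSchedulerSampler(concat_dataset, batch_size=32)

reported = len(sampler)            # 32 * ceil(1000 / 32) * 3 = 3072
actual = len(list(iter(sampler)))  # 94 outer steps * 3 tasks * 11 samples = 3102
print(reported, actual)

With these example sizes the two values disagree (3072 vs 3102), so is that kind of mismatch what you meant?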