Checkpoints not getting created with custom sampler


I am working on a multi-task model with datasets of uneven size. I have a custom sampler and pass it to the DataLoader (below):

sampler = BalancedBatchSchedulerSampler(dataset, batch_size)
dataloader = DataLoader(dataset, sampler=sampler, ...)

BalancedBatchSchedulerSampler is the custom sampler. I have also set replace_sampler_ddp to False. With this custom sampler, the checkpoint folder is not created for the model. When I don’t pass the sampler argument and use the default RandomSampler, the checkpoint is created with no other change to the code.

Is it possible that the sampler is affecting the model checkpoint somehow?

Thank you!

Maybe your custom sampler reports a very large size, so you never reach the end of it. Check the implementation of the __len__ method.
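
One way to act on this suggestion is to compare what the sampler reports through __len__ with the number of indices it actually yields in one pass. The sketch below uses a deliberately inconsistent toy sampler; `ToySampler` and `count_yielded` are hypothetical names for illustration only, not part of PyTorch.

```python
class ToySampler:
    """Hypothetical sampler whose __len__ disagrees with what __iter__ yields."""

    def __init__(self, num_indices, declared_len):
        self.num_indices = num_indices
        self.declared_len = declared_len

    def __iter__(self):
        # Yields exactly num_indices indices.
        return iter(range(self.num_indices))

    def __len__(self):
        # Reports a different size on purpose.
        return self.declared_len


def count_yielded(sampler):
    """Exhaust one pass over the sampler and count the indices it produces."""
    return sum(1 for _ in sampler)


sampler = ToySampler(num_indices=90, declared_len=120)
declared = len(sampler)
yielded = count_yielded(sampler)
print(f"declared={declared}, yielded={yielded}")
```

Running the same two calls on the real sampler should show whether the two numbers agree.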

Thanks so much for pointing me to the __len__ method. I have a multi-task model, and the three tasks have datasets of different sizes. The __len__ method is implemented to return an equal proportion of samples from each dataset. I went through the PyTorch documentation on Sampler and DataLoader, but I'm not sure what I should change. Did you mean that the number of samples is too large?

Thank you

import math
from random import shuffle
import torch
from torch.utils.data import RandomSampler

class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """Iterate over tasks and provide a balanced batch per task in each mini-batch."""

    def __init__(self, dataset, batch_size):

        super(BalancedBatchSchedulerSampler, self).__init__(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max(
            [len(cur_dataset) for cur_dataset in dataset.datasets]
        )

    def __len__(self):
        return (
            self.batch_size
            * math.ceil(self.largest_dataset_size / self.batch_size)
            * len(self.dataset.datasets)
        )

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + list(self.dataset.cumulative_sizes[:-1])
        step = self.batch_size
        samples_to_grab = math.ceil(self.batch_size / self.number_of_datasets)
        # we want to cover every sample of the largest dataset, which forces us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            cur_batch_samples = []
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # reached the end of the iterator: restart it and keep drawing samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                cur_batch_samples.extend(cur_samples)
            final_samples_list.extend(cur_batch_samples)

        return iter(final_samples_list)
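
As a rough cross-check of the listing above, the value returned by __len__ and the number of indices produced by __iter__ can be worked out by hand. The sketch below only mirrors the arithmetic and rests on two assumptions: that the leading factor in __len__ is batch_size, and that every index drawn in the inner loop ends up in the returned list. The dataset sizes and batch size are hypothetical.

```python
import math


def declared_length(sizes, batch_size):
    """Value returned by __len__, assuming the leading factor is batch_size."""
    return batch_size * math.ceil(max(sizes) / batch_size) * len(sizes)


def yielded_count(sizes, batch_size):
    """Indices produced by __iter__, assuming every drawn index is kept."""
    n = len(sizes)
    epoch_samples = max(sizes) * n
    outer_steps = len(range(0, epoch_samples, batch_size))  # step = batch_size
    samples_to_grab = math.ceil(batch_size / n)
    return outer_steps * n * samples_to_grab


sizes = [100, 200, 300]  # hypothetical dataset sizes
declared = declared_length(sizes, batch_size=8)
yielded = yielded_count(sizes, batch_size=8)
print(declared, yielded)  # 912 1017
```

Under these assumptions the two counts need not agree, which is worth ruling out before looking elsewhere.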

That was my guess.
Based on the code, I think you’re implementing a BatchSampler, not a Sampler.
I don’t know how to fix your code, though.
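
To illustrate the distinction: a sampler yields one index at a time, while a batch sampler yields a whole list of indices per step, and its __len__ counts batches rather than samples. The sketch below is plain Python, not the code from this thread. `BalancedBatches`, its parameters, and the sizes are hypothetical, and it assumes the indices follow a ConcatDataset-style layout where each dataset's indices are offset by the sizes of the datasets before it.

```python
import math
import random


class BalancedBatches:
    """Hypothetical batch sampler: each batch holds per_task indices from every dataset."""

    def __init__(self, sizes, per_task, seed=0):
        self.sizes = sizes
        self.per_task = per_task
        self.rng = random.Random(seed)
        # Offset of each dataset inside the concatenated index space.
        self.offsets = [sum(sizes[:i]) for i in range(len(sizes))]

    def __len__(self):
        # Number of batches (not samples) needed to cover the largest dataset once.
        return math.ceil(max(self.sizes) / self.per_task)

    def _endless(self, task):
        # Reshuffle and restart a dataset whenever it runs out (oversampling).
        while True:
            order = list(range(self.sizes[task]))
            self.rng.shuffle(order)
            for idx in order:
                yield idx + self.offsets[task]

    def __iter__(self):
        streams = [self._endless(t) for t in range(len(self.sizes))]
        for _ in range(len(self)):
            batch = []
            for stream in streams:
                batch.extend(next(stream) for _ in range(self.per_task))
            yield batch


batches = list(BalancedBatches(sizes=[10, 20, 30], per_task=4))
print(len(batches), len(batches[0]))
```

With PyTorch, an object like this would go to DataLoader through the batch_sampler argument instead of sampler, in which case batch_size, shuffle, sampler, and drop_last have to be left at their defaults.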

I disabled the __len__ method and the checkpoint is now being created. I think I will play around a bit with a batch sampler.
Thanks again for the help!

Maybe you can get what you want by playing with the batch size,
something like:

import torch
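# toy datasets; the first size and the fill values are illustrative
d1 = torch.utils.data.TensorDataset(torch.ones(100, 2))
d2 = torch.utils.data.TensorDataset(2 * torch.ones(200, 3))
d3 = torch.utils.data.TensorDataset(3 * torch.ones(300, 4))
dl1 = torch.utils.data.DataLoader(d1, batch_size=5, shuffle=True)
dl2 = torch.utils.data.DataLoader(d2, batch_size=10, shuffle=True)
dl3 = torch.utils.data.DataLoader(d3, batch_size=15, shuffle=True)
dl = zip(dl1, dl2, dl3)
for m, n, p in dl:
    pass  # do the job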
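
The batch sizes in this suggestion are presumably chosen in proportion to the dataset sizes so that every loader produces the same number of batches, since zip stops as soon as the shortest loader is exhausted. A quick arithmetic check of that idea, assuming the first dataset has 100 rows (that value is a guess, the other sizes and batch sizes come from the snippet):

```python
import math

# (dataset size, batch size) pairs; 100 is an assumed size for the first dataset.
loaders = {"dl1": (100, 5), "dl2": (200, 10), "dl3": (300, 15)}

# Without drop_last, a DataLoader yields ceil(size / batch_size) batches.
batches_per_loader = {
    name: math.ceil(size / batch) for name, (size, batch) in loaders.items()
}

# zip() stops at the shortest iterable, so this is the number of joint steps.
joint_steps = min(batches_per_loader.values())
print(batches_per_loader, joint_steps)
```

If the ratios were not equal, the extra batches of the longer loaders would simply be dropped by zip in each epoch.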