Batch sampler for a multitask dataset

I have several datasets, each implemented as a separate dataset class, and I am trying to do multi-task training where each batch is drawn from one dataset with probability inversely proportional to that dataset's size, so that the tasks balance out.
But I am running into an error when doing so: I receive a "TypeError: object of type 'int' has no len()" at the marked line in the code below.
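For what it's worth, the message itself is just what Python raises when len() is called on an int, so something the sum() touches must be an int. A standalone illustration (not my training code):

```python
# Standalone illustration: calling len() on an int produces exactly this
# TypeError, so one element reached by the sum() is presumably an int.
try:
    len(3)
except TypeError as err:
    print(err)  # object of type 'int' has no len()
```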

import numpy as np
import torch

class BalancedSingleDatasetBatchSampler(torch.utils.data.Sampler):
    def __init__(self, datasets, batch_size, drop_last=False):
        self.datasets = datasets
        self.batch_size = batch_size
        total_dataset_len = sum(len(dataset) for dataset in datasets)  # <-- TypeError raised here
        # Probability of picking each dataset, weighted by its size
        self.dataset_probabilities = [len(dataset) / total_dataset_len for dataset in datasets]
        self.drop_last = drop_last

    def __iter__(self):
        dataset_indices = list(range(len(self.datasets)))  # one index per dataset
        while True:
            # Choose a dataset to draw a batch from, with probability
            # proportional to the dataset's size
            dataset_index = np.random.choice(dataset_indices, p=self.dataset_probabilities)
            dataset = self.datasets[dataset_index]
            sample_indices = list(range(len(dataset)))  # indices for samples in this dataset
            np.random.shuffle(sample_indices)  # shuffle the sample order
            for i in range(0, len(dataset), self.batch_size):
                yield sample_indices[i:i + self.batch_size]

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets) // self.batch_size
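For reference, the same len()-based probability arithmetic runs fine when every element actually supports len(). A minimal sketch with dummy stand-ins (plain classes of my own, no torch needed):

```python
class DummyDataset:
    """Illustrative stand-in for a dataset class; only len() matters here."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

datasets = [DummyDataset(10), DummyDataset(30)]
total_dataset_len = sum(len(d) for d in datasets)  # no TypeError: every element has len()
probabilities = [len(d) / total_dataset_len for d in datasets]
print(probabilities)  # [0.25, 0.75]
```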

The sampler is used as follows:

train_datasets = [TextDataset(self.tokenizer, self.hparams.data_dir, self.hparams.max_seq_length,
                              task_name=task, split='train')
                  for task in self.hparams.tasks]
multi_train_dataset = ConcatDataset(train_datasets)
weighted_sampler = BalancedSingleDatasetBatchSampler(train_datasets, self.hparams.train_batch_size)
dataloader = DataLoader(multi_train_dataset,
                        batch_sampler=weighted_sampler,
                        collate_fn=TaskCollator(self.tokenizer),
                        # batch_size=self.hparams.train_batch_size,  # removed: mutually exclusive with batch_sampler
                        num_workers=8)

It does not make sense that the dataset is an int. I have checked the type outside of the batch sampler and it is not an int, but strangely, when I check inside the sampler, iterating over datasets yields ints after the dataset object.
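To make that check concrete, this hypothetical helper (my own naming, not part of the training code) is the kind of probe I mean dropping into __init__ to see what the sampler actually receives:

```python
def report_element_types(datasets):
    # Debugging helper: the type of every element the sampler will call
    # len() on; any 'int' entry here would explain the TypeError.
    return [type(d).__name__ for d in datasets]

# e.g. a list that accidentally contains an int alongside a real sequence
print(report_element_types([[1, 2, 3], 5]))  # ['list', 'int']
```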