I have a number of datasets, each created as a separate dataset class, and I am trying to perform multi-task training where each batch is sampled from one dataset with probability inversely proportional to the size of that dataset, to balance things out.
But for some reason I am facing an error when doing so.
I received a "TypeError: object of type 'int' has no len()" at the marked line in the following code:
class BalancedSingleDatasetBatchSampler(torch.utils.data.Sampler):
    """Batch sampler where every batch comes from exactly one dataset.

    Intended to be passed as ``batch_sampler=`` to a ``DataLoader`` that
    wraps ``ConcatDataset(datasets)`` built from the SAME list in the SAME
    order: the indices yielded here are offset into the concatenated
    (global) index space, so each batch addresses a single dataset's slice.

    NOTE(review): the original code chose a dataset with probability
    *proportional* to its size (which gives every individual sample an
    equal chance), even though the surrounding text asks for inverse
    proportionality; the proportional behavior is kept here — confirm
    which weighting is actually wanted.

    Args:
        datasets: non-empty sequence of sized datasets (each supports ``len``).
        batch_size: number of samples per yielded batch.
        drop_last: if True, drop the final partial batch's worth of slots
            (``total // batch_size`` batches); otherwise round up.
    """

    def __init__(self, datasets, batch_size, drop_last=False):
        super().__init__(None)
        if not datasets:
            raise ValueError("datasets must be a non-empty sequence")
        self.datasets = datasets
        self.batch_size = batch_size
        self.drop_last = drop_last
        total_len = sum(len(dataset) for dataset in datasets)
        # P(choose dataset d) = len(d) / total — proportional to size.
        self.dataset_probabilities = [len(dataset) / total_len for dataset in datasets]

    def __iter__(self):
        """Yield ``len(self)`` batches of global (ConcatDataset) indices."""
        # Start offset of each dataset inside the concatenated index space.
        starts = [0]
        for dataset in self.datasets[:-1]:
            starts.append(starts[-1] + len(dataset))
        # One shuffled local-index pool per dataset, refilled when it runs
        # low, so every yielded batch is full and drawn without immediate
        # repetition within a dataset's epoch.
        pools = [[] for _ in self.datasets]
        for _ in range(len(self)):
            d = int(np.random.choice(len(self.datasets), p=self.dataset_probabilities))
            while len(pools[d]) < self.batch_size:
                fresh = list(range(len(self.datasets[d])))
                np.random.shuffle(fresh)
                pools[d].extend(fresh)
            batch, pools[d] = pools[d][:self.batch_size], pools[d][self.batch_size:]
            # Offset local indices into the global index space. The original
            # yielded local indices, so the DataLoader always read from the
            # first dataset's region regardless of which dataset was chosen.
            yield [starts[d] + i for i in batch]

    def __len__(self):
        total = sum(len(dataset) for dataset in self.datasets)
        if self.drop_last:
            return total // self.batch_size
        # Round up so no sample's slot is silently dropped (drop_last was
        # previously stored but ignored).
        return -(-total // self.batch_size)
The sampler is used as follows:
# Build one TextDataset per task, then concatenate them so a single
# DataLoader can address all tasks through one global index space.
train_datasets = [TextDataset(self.tokenizer, self.hparams.data_dir, self.hparams.max_seq_length,task_name=task,split = 'train') for task in self.hparams.tasks]
multi_train_dataset = ConcatDataset(train_datasets)
# NOTE(review): the sampler receives the *list* of datasets (not the
# ConcatDataset) so it can compute per-dataset sizes; the indices it
# yields must therefore be offsets into ConcatDataset's global index
# space — confirm the sampler applies the per-dataset start offsets.
weighted_sampler = BalancedSingleDatasetBatchSampler(train_datasets,self.hparams.train_batch_size)
# With batch_sampler= the sampler yields whole batches of indices, so
# batch_size/shuffle/drop_last must NOT also be passed to DataLoader.
dataloader = DataLoader(multi_train_dataset,
batch_sampler = weighted_sampler,
collate_fn = TaskCollator(self.tokenizer),
# batch_size = self.hparams.train_batch_size,
num_workers = 8)
It does not make sense that the dataset is an int: I have checked its type outside the batch sampler and it is not an int. Strangely, when I check inside the sampler code, iterating over `datasets` yields the dataset object first and then a series of ints.