Sampling from a concatenated dataset

Here is the code I am using:

trainset = ConcatDataset([datasets.ImageFolder(train_dir, train_transform), datasets.ImageFolder(train_dir, valid_transform)])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

What I would like to do is pass a sampler instead of shuffle=True:

trainset = ConcatDataset([datasets.ImageFolder(train_dir, train_transform), datasets.ImageFolder(train_dir, valid_transform)])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=any_sampler_function(trainset), shuffle=False)
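For reference, DataLoader's `sampler=` expects a Sampler object (or any iterable of indices with `__len__`) that yields integer indices into the ConcatDataset, and it raises a ValueError if `shuffle=True` is passed alongside it. A "tuple" complaint could plausibly come from a helper that returns something like `(sampler, weights)` or from a stray trailing comma, but that is a guess since the exact message isn't shown. Below is a minimal sketch, assuming the two sub-datasets have the same length; the class `OneCopyPerItemSampler` and its behaviour (pick one random copy of each item per epoch) are hypothetical, and `TensorDataset` stand-ins replace the ImageFolder instances so it runs without image files.

```python
import random
from typing import Iterator

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset


class OneCopyPerItemSampler(Sampler):
    """Hypothetical sampler for a ConcatDataset made of equally sized copies
    of the same data: yields each item once per epoch, in shuffled order,
    taking it from a randomly chosen copy."""

    def __init__(self, concat: ConcatDataset, seed: int = 0) -> None:
        sizes = [len(ds) for ds in concat.datasets]
        if len(set(sizes)) != 1:
            raise ValueError("all sub-datasets must have the same length")
        self.item_count = sizes[0]
        # Global index where each sub-dataset starts inside the ConcatDataset.
        self.offsets = [0] + concat.cumulative_sizes[:-1]
        # The RNG state persists between epochs, so each epoch gets a new order.
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        order = list(range(self.item_count))
        self.rng.shuffle(order)
        for i in order:
            # Local index i in a randomly chosen copy -> global ConcatDataset index.
            yield self.rng.choice(self.offsets) + i

    def __len__(self) -> int:
        return self.item_count


def any_sampler_function(dataset: ConcatDataset) -> Sampler:
    # Return the sampler object itself, not a tuple such as (sampler, weights).
    return OneCopyPerItemSampler(dataset)


# Stand-ins for the two ImageFolder copies: the same 6 items with labels 0..5.
base = torch.arange(6).float().unsqueeze(1)
trainset = ConcatDataset([TensorDataset(base, torch.arange(6)),
                          TensorDataset(base * 10, torch.arange(6))])
trainloader = DataLoader(trainset, batch_size=3, sampler=any_sampler_function(trainset))
```

With the real data, the same call works on the ConcatDataset of the two ImageFolder instances, since the sampler only uses `concat.datasets` lengths and `concat.cumulative_sizes`.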

I believe I tried WeightedRandomSampler, but I get the same error as with my custom sampler, something like “a tuple was provided while a sampler object is required”.
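WeightedRandomSampler needs one weight per index of the concatenated dataset, in the order ConcatDataset uses (all of the first sub-dataset, then all of the second). A minimal sketch follows, assuming class-balanced sampling is the goal (the post doesn't say which weights are wanted) and that each sub-dataset exposes a `targets` list the way torchvision's ImageFolder does. `FakeImageFolder` and `make_balanced_sampler` are hypothetical names; the stand-in class only exists so the example runs without image files.

```python
from collections import Counter

import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, WeightedRandomSampler


class FakeImageFolder(Dataset):
    """Hypothetical stand-in for datasets.ImageFolder that exposes .targets."""

    def __init__(self, targets):
        self.targets = list(targets)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # Dummy 3x4x4 "image" plus its class label.
        return torch.zeros(3, 4, 4), self.targets[index]


def make_balanced_sampler(concat: ConcatDataset) -> WeightedRandomSampler:
    # ConcatDataset walks concat.datasets in order, so chaining each
    # sub-dataset's targets gives exactly one label per global index.
    targets = [t for ds in concat.datasets for t in ds.targets]
    counts = Counter(targets)
    # Inverse-frequency weights: every class gets the same total weight.
    weights = [1.0 / counts[t] for t in targets]
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)


# Imbalanced labels (three 0s, one 1), duplicated like the two-transform ConcatDataset.
trainset = ConcatDataset([FakeImageFolder([0, 0, 0, 1]), FakeImageFolder([0, 0, 0, 1])])
sampler = make_balanced_sampler(trainset)
# shuffle is left at its default (False), since it cannot be combined with sampler.
trainloader = DataLoader(trainset, batch_size=4, sampler=sampler)
```

Because both copies in the original ConcatDataset read the same folder, the class counts are simply doubled, so the per-sample weights come out the same as they would for a single ImageFolder.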