Sampler with unique classes per batch

I would like to make a sampler for my dataloader. I have 12 unique classes in my dataset, and it is really important that each batch contains no more than one element of each class. The batch size doesn’t matter as long as this requirement is met. I’ve tried WeightedRandomSampler, but it still produces batches with duplicate classes in about 40% of cases (with batch size = 4). This is what I have for the weighted sampler, but I don’t know where to go from here:

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

def get_targets(dataset):
    """Get all labels in dataset (assumes each sample is an (input, label) pair with an int label)."""
    targets = []
    for i in range(len(dataset)):
        sample = dataset[i]
        targets.append(sample[1])
    return targets

def class_weights(target):
    """Get class weights: 1 / (number of samples in that class)."""
    unique_patients = np.unique(np.array(target))
    n_patients = len(unique_patients)
    print("Number of unique patients...", n_patients)
    patient_weights = {}
    for patient in unique_patients:
        sample_count = 0
        for n in range(0, len(target)):
            if target[n] == patient:
                sample_count += 1
        patient_weights[patient] = 1 / sample_count
    return patient_weights

def make_sampler(dataset):
    """Make weighted sampler."""
    targets = get_targets(dataset)
    weight = class_weights(targets)
    samples_weight = np.array([weight[t] for t in targets])
    samples_weight = torch.from_numpy(samples_weight)
    sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
    return sampler
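For context, the ~40% figure is roughly what you would expect. WeightedRandomSampler draws with replacement by default, so if the weights make every class equally likely, each slot in a batch of 4 is an independent pick from 12 classes (the birthday problem). The sketch below assumes exactly uniform class probabilities; the helper name `prob_duplicate_class` is hypothetical.

```python
from math import perm

def prob_duplicate_class(num_classes: int, batch_size: int) -> float:
    """Chance that a batch of independent, uniformly drawn classes contains a repeat."""
    if batch_size > num_classes:
        return 1.0  # pigeonhole: a repeat is guaranteed
    # ordered draws with all classes distinct, divided by all ordered draws
    p_all_distinct = perm(num_classes, batch_size) / num_classes ** batch_size
    return 1.0 - p_all_distinct

print(f"{prob_duplicate_class(12, 4):.3f}")  # 0.427
```

That is 1 - (12·11·10·9)/12^4 ≈ 0.427, close to the observed rate, so reweighting alone cannot rule out repeats; the classes in a batch have to be drawn without replacement.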


If you know your batch size, you could just make a random vector with one entry per class, and then take its topk with k equal to your batch size to choose which classes to sample from.

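import torch

num_classes = 12
num_batches = 4  # number of classes (and therefore samples) per batch
vect = torch.rand(num_classes)
vect_topk = torch.topk(vect, num_batches).indices  # num_batches distinct classes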
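Putting the topk idea together with a per-class index lookup gives a minimal sketch like the one below. The `class_indices` layout (3 samples per class, stored contiguously) is a made-up example, and `batch_size` plays the role of `num_batches` above.

```python
import torch

num_classes = 12
batch_size = 4  # distinct classes per batch; must be <= num_classes

# Hypothetical index layout: 3 samples per class, class c owns indices 3c..3c+2
class_indices = [torch.arange(3 * c, 3 * c + 3) for c in range(num_classes)]

# topk over random scores returns batch_size distinct classes in random order
# (torch.randperm(num_classes)[:batch_size] would do the same job)
chosen = torch.topk(torch.rand(num_classes), batch_size).indices.tolist()

# one random sample index from each chosen class
batch = [class_indices[c][torch.randint(len(class_indices[c]), (1,))].item() for c in chosen]
print(chosen, batch)
```

Because topk never returns the same position twice, no class can appear twice in `batch`; the open question is only how to make sure every sample is eventually used.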

Thanks for the suggestion! I’ve tried it out and the solution works! I’ve realised that I need to add more code to make sure that every sample gets chosen. Also, the number of elements per class varies from 3 to 20. Do you have a suggestion for this?

There are two approaches you could implement to make sure every sample gets used in each training epoch:

  1. Define your num_classes dynamically, based on how many classes still have unused samples; for example, you could store the indices of each class in a list of numpy arrays and remove indices as they are used.

And then define num_batches to also be dynamic:

num_batches = min(4, num_classes)  # never ask for more classes than remain

But this method may result in overfitting to the last remaining class(es). For instance, if the class with the most elements fills the last 3-5 batches on its own, you might find the model overfits to that class.

  2. Keep cycling through each class independently, while randomly sampling full batches, until every class has completed at least one cycle. You could maintain an array with the number of elements in each class and subtract 1 each time a class is chosen, then end the training epoch once every count has reached zero.
num_elements_per_class = np.array([3, 6, 6, 11, 14, ... 20])  # one count per class (list abbreviated)

vect = torch.rand(num_classes)
vect_topk = torch.topk(vect, num_batches).indices.numpy()

#deduct from classes sampled
num_elements_per_class[vect_topk] -= 1

And then your per-epoch training loop might look like:

while True:
    #training tasks go here
    #break training if all sampled
    if not np.any(dataloader.num_elements_per_class > 0):
        break

The second approach is better at limiting class overfitting, but you’d probably need to build your own custom dataloader, since the vanilla PyTorch DataLoader expects a __len__.
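For comparison, the first approach (shrinking the batch once too few classes have samples left) can be sketched as a generator over dataset indices. The function name `approach_one_epoch` and the toy labels are hypothetical, and the sketch assumes labels are plain integers.

```python
import numpy as np

def approach_one_epoch(labels, max_batch_size=4, seed=None):
    """Approach 1 sketch: yield index batches with at most one index per class,
    shrinking the batch once fewer than max_batch_size classes have samples left.
    Every index is used exactly once per epoch."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    # one shuffled pool of dataset indices per class
    pools = {int(c): rng.permutation(np.flatnonzero(labels == c)).tolist()
             for c in np.unique(labels)}
    while pools:
        remaining = list(pools)
        batch_size = min(max_batch_size, len(remaining))  # the dynamic batch size
        chosen = rng.choice(remaining, size=batch_size, replace=False)
        batch = [pools[int(c)].pop() for c in chosen]
        # drop classes whose pool is now empty
        for c in chosen:
            if not pools[int(c)]:
                del pools[int(c)]
        yield batch

labels = [0] * 3 + [1] * 8 + [2] * 2
for batch in approach_one_epoch(labels, max_batch_size=2, seed=0):
    print(batch, [labels[i] for i in batch])
```

With this toy data, class 1 has more samples than the other two classes combined, so the last few batches can only contain class 1; that is the overfitting risk described for the first approach.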

Basically, you just need to index where your classes are in the dataset, and maintain a list of those indices, plus a list of those indices but shuffled. Something like the following might work, but you may need to debug it since I wrote it on my phone and haven’t checked:

import numpy as np

#get list of indices for each class
def list_class_indices(labels):
    labels = np.asarray(labels)
    return [np.where(labels == x)[0] for x in range(num_classes)]

#shuffle list of indices
rng = np.random.default_rng()

def shuffle_indices(indices_list):
    # rng.shuffle works in place and returns None, so use rng.permutation instead
    return [rng.permutation(x) for x in indices_list]

Maintain a list of indices_counter and total_indices:

def init_indices_counter(class_indices):
    # integer counter so it can be used directly as an index
    return np.zeros(len(class_indices), dtype=int), np.array([len(x) for x in class_indices])

An update function for indices_counter:

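def update_indices_counter(indices_counter, total_indices, vect_topk):
    #advance the counter of every class that was just sampled
    indices_counter[vect_topk] += 1
    condition = indices_counter == total_indices
    #reset counter where cycle completed
    indices_counter[condition] = 0
    return indices_counter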

And then a get_batch function:

def get_batch(data, labels, shuffled_indices, vect_topk, indices_counter, total_indices, num_elements_per_class):
    #shuffled_indices is a Python list, so take one index per chosen class explicitly
    indices = [shuffled_indices[c][indices_counter[c]] for c in vect_topk]
    indices_counter = update_indices_counter(indices_counter, total_indices, vect_topk)

    #update num_elements_per_class
    num_elements_per_class[vect_topk] -= 1
    return data[indices], labels[indices]

And you will likely need a reset function between training epochs to shuffle indices and to reset the indices_counter and num_elements_per_class.

The above are independent definitions, but you could probably clean them up quite a bit with a class.
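As a rough sketch of that class-based cleanup for the second approach, the pieces above (per-class shuffled indices, a per-class counter, and num_elements_per_class) can live on one object. The class name `CyclingClassIndexer` is hypothetical; labels are assumed to be integers 0..num_classes-1 with every class present, and the distinct classes are picked with numpy's `Generator.choice(..., replace=False)` rather than torch.topk (both give distinct classes in random order).

```python
import numpy as np

class CyclingClassIndexer:
    """Approach 2 sketch: every batch holds batch_size distinct classes; each class
    cycles through its own shuffled indices, and an epoch ends once every class
    has been fully cycled at least once."""

    def __init__(self, labels, batch_size, seed=None):
        self.labels = np.asarray(labels)
        self.num_classes = int(self.labels.max()) + 1  # assumes labels 0..num_classes-1, all present
        if batch_size > self.num_classes:
            raise ValueError("batch_size cannot exceed the number of classes")
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self.class_indices = [np.flatnonzero(self.labels == c) for c in range(self.num_classes)]
        self.total_indices = np.array([len(x) for x in self.class_indices])
        self.reset()

    def reset(self):
        """Call between epochs: reshuffle and zero the counters."""
        self.shuffled = [self.rng.permutation(x) for x in self.class_indices]
        self.counter = np.zeros(self.num_classes, dtype=int)
        self.num_elements_per_class = self.total_indices.copy()

    def next_batch(self):
        chosen = self.rng.choice(self.num_classes, size=self.batch_size, replace=False)
        batch = [int(self.shuffled[c][self.counter[c]]) for c in chosen]
        self.counter[chosen] += 1
        self.num_elements_per_class[chosen] -= 1
        # restart (with a fresh shuffle) any class that just completed a cycle
        for c in chosen[self.counter[chosen] == self.total_indices[chosen]]:
            self.counter[c] = 0
            self.shuffled[c] = self.rng.permutation(self.class_indices[c])
        return batch

    def epoch(self):
        self.reset()
        while np.any(self.num_elements_per_class > 0):
            yield self.next_batch()

labels = [0] * 3 + [1] * 6 + [2] * 4 + [3] * 2
indexer = CyclingClassIndexer(labels, batch_size=2, seed=0)
for batch in indexer.epoch():
    print(batch, [labels[i] for i in batch])
```

Batch size stays constant all epoch; small classes simply start a new cycle while the large ones finish, which is what keeps any single class from dominating the final batches.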


Thanks for all the tips! This is really helpful and it works so far. This is what I’ve got:

def list_class_indices(labels):
    class_indices = []
    for i in range(0, num_classes):
        class_idx = []
        for j in range(0, len(labels)):
            if labels[j] == i:
                class_idx.append(j)
        class_indices.append(class_idx)
    return class_indices

def shuffle_indices(indices_list):
    shuffled_indices = []
    for i in range(0, len(indices_list)):
        shuffled_indices.append((sorted(indices_list[i], key=lambda k: random.random())))
    return shuffled_indices

def init_indices_counter(class_indices):
    return np.zeros(len(class_indices)), np.array([len(x) for x in class_indices])

def update_indices_counter(indices_counter, total_indices, vect_topk, indices, indices_list):
    condition = indices_counter == total_indices
    #reset counter where cycle completed
    indices_counter[condition] = 0
    # update indices list
    new_indices = []
    # Go through every class
    for i in range(0, len(indices_list)):
        class_idx = indices_list[i]
        new_class_idx = []
        # Go through every class index
        for j in range(0, len(class_idx)):
            # Keep the index only if it was not among the chosen indices
            if class_idx[j] not in indices:
                new_class_idx.append(class_idx[j])
        new_indices.append(new_class_idx)
    return indices_counter, new_indices

def get_indices(indices, vect_topk):
    data = []
    for i in vect_topk:
        # first (already shuffled) index of class i
        x = indices[i][0]
        data.append(x)
    return data

def get_dataset_indices(dataset, indices):
    x = []
    for i in indices:
        x.append(dataset[i])
    return x

def get_batch(dataset, indices_list, indices_counter, total_indices):
    # Shuffle indices list
    indices_list = shuffle_indices(indices_list)
    # Random score per class
    vect = torch.rand(num_classes)
    # Pick up to 64 distinct classes (at most one index per class)
    vect_topk = torch.topk(vect, min(64, num_classes)).indices.tolist()
    # Get one index from each chosen class
    indices = get_indices(indices_list, vect_topk)
    # Get these indices from dataset
    data = get_dataset_indices(dataset, indices)
    # Update counter
    indices_counter, new_indices = update_indices_counter(indices_counter, total_indices, vect_topk, indices, indices_list)
    return data, new_indices, indices_counter

indices_list = list_class_indices(targets)
counter, total_indices = init_indices_counter(indices_list)

for i in range(0, 40):
    data, indices_list, counter = get_batch(val_ds, indices_list, counter, total_indices)

Unfortunately, it is still quite slow and could significantly slow down training. Is there something I can improve on to make it go faster?

I was also trying to make a sampler out of it but it also seems quite slow.

from torch.utils.data import Sampler

class MySampler(Sampler):
    def __init__(self, dataset, batch_size = 64):
        self.dataset = dataset
        self.batch_size = batch_size
        self.targets = get_targets(dataset)
        self.unique_targets = np.unique(np.array(self.targets))
        self.num_classes = len(self.unique_targets)
        self.indices_list = list_class_indices(self.targets)
        self.counter, self.total_indices = init_indices_counter(self.indices_list)
    def __iter__(self):
        vect = torch.rand(self.num_classes)
        # at most one index per class, so at most num_classes indices per call
        vect_topk = torch.topk(vect, min(self.batch_size, self.num_classes)).indices.tolist()
        indices = get_indices(self.indices_list, vect_topk)
        self.counter, self.indices_list = update_indices_counter(self.counter, self.total_indices, vect_topk, indices, self.indices_list)
        return iter(indices)

sampler = MySampler(val_ds, batch_size = 64)
val_dl = DataLoader(val_ds, sampler = sampler,  batch_size=64,   shuffle=False, drop_last=True)

Thanks for all the help! I appreciate it.

Any time you are running a lot of Python “for” iterations, the data is processed sequentially, one element after the other in the interpreter, which is slow.

I would just cast this to tensors and run it as a masking operation:

### BEFORE ###
def list_class_indices(labels):
    class_indices = []
    for i in range(0, num_classes):
        class_idx = []
        for j in range(0, len(labels)):
            if labels[j] == i:
                class_idx.append(j)
        class_indices.append(class_idx)
    return class_indices

### AFTER ###
import torch

def list_class_indices(labels, num_classes):
    labels = torch.tensor(labels, dtype=torch.int32)
    indices = torch.arange(labels.size(0))
    # one boolean mask per class selects that class's indices
    return [indices[labels == i] for i in range(num_classes)]

And to test it:

x = torch.randint(0, 12, (100,)).tolist()

class_indices = list_class_indices(x, 12)
print(class_indices[0])


This should print a tensor containing the indices of class 0. What’s nice about this method is that you can also move the tensors to a GPU to take advantage of CUDA cores.
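The mask version still loops over classes in Python and scans all labels once per class. If that ever becomes a bottleneck, one alternative (not from the answer above) groups the indices with a single stable sort plus `torch.bincount`. The function name `list_class_indices_sorted` is hypothetical.

```python
import torch

def list_class_indices_sorted(labels, num_classes):
    """Group sample indices by class with one stable sort instead of one mask per class."""
    labels = torch.as_tensor(labels, dtype=torch.int64)
    # a stable sort keeps the indices inside each class in ascending order
    _, order = torch.sort(labels, stable=True)
    # number of samples per class, including classes with zero samples
    counts = torch.bincount(labels, minlength=num_classes)
    return list(torch.split(order, counts.tolist()))

x = torch.randint(0, 12, (100,)).tolist()
grouped = list_class_indices_sorted(x, 12)
print(grouped[0])
```

Both versions return the same per-class index tensors; the sorted one does O(N log N) work once instead of num_classes full passes over the labels.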

On a side note, any time you’re using a bunch of functions that work together, it’s probably best to wrap them in a class. That way you can use self.value1 instead of passing values into each function. It’s cleaner, needs less code, and is easier to organize. For example:

import torch

class SuperIndexer():
    def __init__(self, labels, data, num_classes):  # ... plus any other arguments you need
        self.labels = torch.tensor(labels, dtype=torch.int32)
        self.data = torch.tensor(data, dtype=torch.float32)
        self.num_classes = num_classes

    def list_class_indices(self):
        indices = torch.arange(self.labels.size(0))
        self.class_indices = [indices[self.labels == i] for i in range(self.num_classes)]
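If you would rather keep the standard DataLoader, one option is to hand it a batch sampler through its `batch_sampler` argument: an iterable that yields whole lists of dataset indices, so the DataLoader never groups samples itself. The class `DistinctClassBatchSampler` and its `num_batches` parameter are hypothetical, every class is assumed to have at least one sample, and this sketch swaps in a simpler strategy than the second approach: each batch takes one random sample from each of batch_size random classes, independently across batches, so it does not guarantee that every sample is seen in an epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class DistinctClassBatchSampler:
    """Hypothetical batch sampler: each yielded batch holds batch_size distinct
    classes, one randomly chosen sample per class."""

    def __init__(self, labels, num_classes, batch_size, num_batches):
        if batch_size > num_classes:
            raise ValueError("batch_size cannot exceed num_classes")
        labels = torch.as_tensor(labels, dtype=torch.int64)
        indices = torch.arange(labels.size(0))
        # same masking idea as list_class_indices
        self.class_indices = [indices[labels == i] for i in range(num_classes)]
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            chosen = torch.randperm(self.num_classes)[: self.batch_size].tolist()
            yield [int(self.class_indices[c][torch.randint(len(self.class_indices[c]), (1,))])
                   for c in chosen]

labels = torch.arange(60) % 12  # toy labels: 5 samples of each of 12 classes
data = torch.randn(60, 3)
sampler = DistinctClassBatchSampler(labels, 12, batch_size=4, num_batches=10)
loader = DataLoader(TensorDataset(data, labels), batch_sampler=sampler)
for xb, yb in loader:
    print(yb.tolist())
```

When `batch_sampler` is given, `batch_size`, `shuffle`, `sampler` and `drop_last` must be left at their defaults. The per-class cycling logic of the second approach could be dropped into `__iter__` in place of the random pick.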


Thank you!! That helps me a lot!