There are two approaches you could take to make sure every sample gets used in each training epoch:

- Define your `num_classes` dynamically, based on how many classes still have untrained samples. For example, you could use a list of NumPy arrays to store the indices for each class.

And then define `num_batches` to also be dynamic:

```
# cap the number of classes sampled per batch (default 4) by how many remain
num_batches = min(4, num_classes)
```

But this method may result in overfitting to the final class(es) remaining. For instance, if the class with the most samples ends up alone in the last 3-5 batches, the model may skew toward that class.
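The remaining-class bookkeeping for that first approach might look like the following (the `remaining` arrays are hypothetical placeholders):

```python
import numpy as np

# hypothetical per-class arrays of the sample indices not yet trained on
remaining = [np.array([4, 9]), np.array([], dtype=int), np.array([1, 5, 7])]

# num_classes counts only the classes that still have untrained samples
num_classes = sum(len(idx) > 0 for idx in remaining)
print(num_classes)  # → 2
```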

- Keep cycling each class independently, while randomly sampling with the full batch size until all classes have completed at least one cycle. You could maintain an array of the number of elements remaining in each class, subtract 1 each time a class is chosen, and break training for that epoch once every count reaches zero.

```
import numpy as np
import torch

# remaining (untrained) samples per class
num_elements_per_class = np.array([3, 6, 6, 11, 14, ... 20])
# pick `num_batches` distinct classes uniformly at random
vect = torch.rand(num_classes)
vect_topk = vect.topk(num_batches)[1].numpy()
# deduct from the classes sampled
num_elements_per_class[vect_topk] -= 1
```

And then your per epoch training loop might look like:

```
# break training for the epoch once every class has been fully sampled
while True:
    # training tasks go here
    if not np.any(dataloader.num_elements_per_class > 0):
        break
```

The second approach would be ideal to limit class overfitting. But you'd probably need to build your own custom dataloader, since the vanilla PyTorch dataloader requires a `__len__`.

Basically, you just need to index where your classes are in the dataset, and maintain a list of those indices, plus a list of those indices but shuffled. Something like the following might work, but you may need to debug it since I wrote it on my phone and haven’t checked:

```
import numpy as np

rng = np.random.default_rng()

# get the list of sample indices for each class
def list_class_indices(labels):
    return [np.where(labels == x)[0] for x in range(num_classes)]

# return an independently shuffled copy of each class's indices
# (rng.shuffle is in-place and returns None, so use rng.permutation)
def shuffle_indices(indices_list):
    return [rng.permutation(x) for x in indices_list]
```
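With a made-up label array (labels and seed here are just for illustration), those two helpers boil down to:

```python
import numpy as np

rng = np.random.default_rng(0)
labels = np.array([0, 1, 0, 2, 1, 2, 2])
num_classes = 3

# one index array per class, each shuffled independently
class_indices = [np.where(labels == c)[0] for c in range(num_classes)]
shuffled = [rng.permutation(x) for x in class_indices]

# each class's shuffled array still holds exactly its own indices
print([sorted(x.tolist()) for x in shuffled])  # → [[0, 2], [1, 4], [3, 5, 6]]
```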

Maintain an `indices_counter` and `total_indices`:

```
def init_indices_counter(class_indices):
    # the counter must be an integer dtype, since it is later used for indexing
    return np.zeros(len(class_indices), dtype=int), np.array([len(x) for x in class_indices])
```

An update function for `indices_counter`:

```
def update_indices_counter(indices_counter, total_indices, vect_topk):
    indices_counter[vect_topk] += 1
    # reset the counter wherever a class has completed a full cycle
    indices_counter[indices_counter == total_indices] = 0
    return indices_counter
```
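To see the wrap-around behaviour concretely (the counter and totals below are made-up values):

```python
import numpy as np

indices_counter = np.array([2, 0, 1])
total_indices = np.array([3, 5, 4])
vect_topk = np.array([0, 2])  # classes sampled this batch

indices_counter[vect_topk] += 1
# class 0 just finished its cycle of 3, so its counter wraps back to 0
indices_counter[indices_counter == total_indices] = 0
print(indices_counter.tolist())  # → [0, 0, 2]
```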

And then a `get_batch` function:

```
def get_batch(data, labels, shuffled_indices, vect_topk, indices_counter,
              total_indices, num_elements_per_class):
    # next shuffled index for each sampled class
    # (shuffled_indices is a plain list, so index per class rather than fancy-indexing)
    indices = [shuffled_indices[c][indices_counter[c]] for c in vect_topk]
    indices_counter = update_indices_counter(indices_counter, total_indices, vect_topk)
    # update num_elements_per_class
    num_elements_per_class[vect_topk] -= 1
    return data[indices], labels[indices]
```

And you will likely need a reset function between training epochs to reshuffle the indices and to reset `indices_counter` and `num_elements_per_class`.
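That reset might look like this (a hypothetical `reset_epoch` helper, mirroring the init functions above):

```python
import numpy as np

rng = np.random.default_rng()

def reset_epoch(class_indices):
    # reshuffle each class and restart both trackers for the new epoch
    shuffled_indices = [rng.permutation(x) for x in class_indices]
    indices_counter = np.zeros(len(class_indices), dtype=int)
    num_elements_per_class = np.array([len(x) for x in class_indices])
    return shuffled_indices, indices_counter, num_elements_per_class
```

You would call it at the top of every epoch, before sampling any batches.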

The above are independent definitions, but you could probably clean them up quite a bit by wrapping everything in a class.
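A rough consolidation into a single sampler class might look like this (an untested sketch; I've swapped the `torch.rand(...).topk(...)` trick for NumPy's `choice` so the example stands alone, but either works):

```python
import numpy as np

class CyclingClassSampler:
    """Sketch: cycle each class independently, drawing one sample per
    chosen class per batch, until every class finishes at least one cycle."""

    def __init__(self, labels, num_classes, classes_per_batch, seed=None):
        self.rng = np.random.default_rng(seed)
        self.num_classes = num_classes
        self.classes_per_batch = classes_per_batch
        self.class_indices = [np.where(labels == c)[0] for c in range(num_classes)]
        self.reset()

    def reset(self):
        # reshuffle every class's indices and restart all trackers
        self.shuffled = [self.rng.permutation(x) for x in self.class_indices]
        self.counter = np.zeros(self.num_classes, dtype=int)
        self.total = np.array([len(x) for x in self.class_indices])
        self.num_elements_per_class = self.total.copy()

    def next_batch_indices(self):
        # distinct random classes (same effect as torch.rand(...).topk(...))
        chosen = self.rng.choice(self.num_classes, self.classes_per_batch, replace=False)
        idx = np.array([self.shuffled[c][self.counter[c]] for c in chosen])
        self.counter[chosen] += 1
        self.counter[self.counter == self.total] = 0  # wrap completed cycles
        self.num_elements_per_class[chosen] -= 1
        return idx

    def epoch_done(self):
        return not np.any(self.num_elements_per_class > 0)
```

Per epoch: call `reset()`, then `next_batch_indices()` repeatedly until `epoch_done()` returns `True`.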