There are two approaches you could take to make sure every sample gets used in each training epoch:

- Define your `num_classes` dynamically, based on how many classes still have untrained samples. For example, you could use a list of NumPy arrays to store the indices for each class.

And then define `num_batches` to also be dynamic:

```
# cap the number of classes sampled per batch (default 4) by how many remain
num_batches = min(4, num_classes)
```

But this method may result in overfitting to the final class(es) remaining. For instance, if the class with the most samples ends up alone in the last 3-5 batches, the model may skew toward that class.
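The remaining-class bookkeeping for that first approach might look like the following (the `remaining` arrays are hypothetical placeholders):

```python
import numpy as np

# hypothetical per-class arrays of the sample indices not yet trained on
remaining = [np.array([4, 9]), np.array([], dtype=int), np.array([1, 5, 7])]

# num_classes counts only the classes that still have untrained samples
num_classes = sum(len(idx) > 0 for idx in remaining)
print(num_classes)  # → 2
```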

- Keep cycling each class independently, while randomly sampling with the full batch size until all classes have completed at least one cycle. You could maintain an array of the number of elements remaining in each class, subtract 1 each time a class is chosen, and break training for that epoch once every count reaches zero.

```
import numpy as np
import torch

# remaining (untrained) samples per class
num_elements_per_class = np.array([3, 6, 6, 11, 14, ... 20])
# pick `num_batches` distinct classes uniformly at random
vect = torch.rand(num_classes)
vect_topk = vect.topk(num_batches)[1].numpy()
# deduct from the classes sampled
num_elements_per_class[vect_topk] -= 1
```

And then your per epoch training loop might look like:

```
# break training for the epoch once every class has been fully sampled
while True:
    # training tasks go here
    if not np.any(dataloader.num_elements_per_class > 0):
        break
```

The second approach would be ideal to limit class overfitting. But you'd probably need to build your own custom dataloader, since the vanilla PyTorch dataloader requires a `__len__`.

Basically, you just need to index where your classes are in the dataset, and maintain a list of those indices, plus a list of those indices but shuffled. Something like the following might work, but you may need to debug it since I wrote it on my phone and haven’t checked:

```
import numpy as np

rng = np.random.default_rng()

# get the list of sample indices for each class
def list_class_indices(labels):
    return [np.where(labels == x)[0] for x in range(num_classes)]

# return an independently shuffled copy of each class's indices
# (rng.shuffle is in-place and returns None, so use rng.permutation)
def shuffle_indices(indices_list):
    return [rng.permutation(x) for x in indices_list]
```
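With a made-up label array (labels and seed here are just for illustration), those two helpers boil down to:

```python
import numpy as np

rng = np.random.default_rng(0)
labels = np.array([0, 1, 0, 2, 1, 2, 2])
num_classes = 3

# one index array per class, each shuffled independently
class_indices = [np.where(labels == c)[0] for c in range(num_classes)]
shuffled = [rng.permutation(x) for x in class_indices]

# each class's shuffled array still holds exactly its own indices
print([sorted(x.tolist()) for x in shuffled])  # → [[0, 2], [1, 4], [3, 5, 6]]
```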

Maintain an `indices_counter` and `total_indices`:

```
def init_indices_counter(class_indices):
    # the counter must be an integer dtype, since it is later used for indexing
    return np.zeros(len(class_indices), dtype=int), np.array([len(x) for x in class_indices])
```

An update function for `indices_counter`:

```
def update_indices_counter(indices_counter, total_indices, vect_topk):
    indices_counter[vect_topk] += 1
    # reset the counter wherever a class has completed a full cycle
    indices_counter[indices_counter == total_indices] = 0
    return indices_counter
```
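To see the wrap-around behaviour concretely (the counter and totals below are made-up values):

```python
import numpy as np

indices_counter = np.array([2, 0, 1])
total_indices = np.array([3, 5, 4])
vect_topk = np.array([0, 2])  # classes sampled this batch

indices_counter[vect_topk] += 1
# class 0 just finished its cycle of 3, so its counter wraps back to 0
indices_counter[indices_counter == total_indices] = 0
print(indices_counter.tolist())  # → [0, 0, 2]
```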

And then a `get_batch` function:

```
def get_batch(data, labels, shuffled_indices, vect_topk, indices_counter,
              total_indices, num_elements_per_class):
    # next shuffled index for each sampled class
    # (shuffled_indices is a plain list, so index per class rather than fancy-indexing)
    indices = [shuffled_indices[c][indices_counter[c]] for c in vect_topk]
    indices_counter = update_indices_counter(indices_counter, total_indices, vect_topk)
    # update num_elements_per_class
    num_elements_per_class[vect_topk] -= 1
    return data[indices], labels[indices]
```

And you will likely need a reset function between training epochs to reshuffle the indices and to reset `indices_counter` and `num_elements_per_class`.
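That reset might look like this (a hypothetical `reset_epoch` helper, mirroring the init functions above):

```python
import numpy as np

rng = np.random.default_rng()

def reset_epoch(class_indices):
    # reshuffle each class and restart both trackers for the new epoch
    shuffled_indices = [rng.permutation(x) for x in class_indices]
    indices_counter = np.zeros(len(class_indices), dtype=int)
    num_elements_per_class = np.array([len(x) for x in class_indices])
    return shuffled_indices, indices_counter, num_elements_per_class
```

You would call it at the top of every epoch, before sampling any batches.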

The above are independent definitions, but you could probably clean them up quite a bit by wrapping everything in a class.
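A rough consolidation into a single sampler class might look like this (an untested sketch; I've swapped the `torch.rand(...).topk(...)` trick for NumPy's `choice` so the example stands alone, but either works):

```python
import numpy as np

class CyclingClassSampler:
    """Sketch: cycle each class independently, drawing one sample per
    chosen class per batch, until every class finishes at least one cycle."""

    def __init__(self, labels, num_classes, classes_per_batch, seed=None):
        self.rng = np.random.default_rng(seed)
        self.num_classes = num_classes
        self.classes_per_batch = classes_per_batch
        self.class_indices = [np.where(labels == c)[0] for c in range(num_classes)]
        self.reset()

    def reset(self):
        # reshuffle every class's indices and restart all trackers
        self.shuffled = [self.rng.permutation(x) for x in self.class_indices]
        self.counter = np.zeros(self.num_classes, dtype=int)
        self.total = np.array([len(x) for x in self.class_indices])
        self.num_elements_per_class = self.total.copy()

    def next_batch_indices(self):
        # distinct random classes (same effect as torch.rand(...).topk(...))
        chosen = self.rng.choice(self.num_classes, self.classes_per_batch, replace=False)
        idx = np.array([self.shuffled[c][self.counter[c]] for c in chosen])
        self.counter[chosen] += 1
        self.counter[self.counter == self.total] = 0  # wrap completed cycles
        self.num_elements_per_class[chosen] -= 1
        return idx

    def epoch_done(self):
        return not np.any(self.num_elements_per_class > 0)
```

Per epoch: call `reset()`, then `next_batch_indices()` repeatedly until `epoch_done()` returns `True`.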