Change batch size based on loss

How do I increase the batch size when the loss decreases, and decrease it when the loss increases? In other words, how do I make the batch size change dynamically based on the loss during training and evaluation?

I'm not sure this is the best way, but maybe something like this:

class Dataset:
    """Serve sequential batches whose size may change from call to call.

    Samples are taken in index order. The caller may pass a different
    ``batch_size`` on every call to ``get_next``, which is what allows a
    loss-driven dynamic batch size without rebuilding a data loader.
    """

    def __init__(self, data_path):
        # NOTE(review): load_data is not defined in this snippet; presumably it
        # returns a sized container of samples — confirm.
        self.data = load_data(data_path)
        self.data_len = len(self.data)
        self.indices = np.arange(0, self.data_len)
        self.current_index = 0

    def reset(self):
        """Rewind to the first sample; call this at the start of every epoch."""
        self.current_index = 0

    def get_next(self, batch_size):
        """Return the next batch of up to ``batch_size`` samples as ``(x, y)``.

        The last batch of an epoch may be smaller than ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` is less than 1. Such a request would
                never advance ``current_index`` and would hang any
                ``while current_index < data_len`` loop.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        end_index = min(self.current_index + batch_size, self.data_len)
        indices = self.indices[self.current_index:end_index]
        # NOTE(review): get_batch_data is not defined in this snippet; it is
        # assumed to turn an array of indices into an (x, y) pair — confirm.
        x, y = get_batch_data(indices)
        self.current_index = end_index
        return x, y

# Dynamic batch size: grow the batch while the loss keeps falling, shrink it
# when the loss rises. The bounds keep batch_size >= 1 (an empty batch would
# stall the while-loop below) and cap memory use.
# NOTE(review): data_path, max_epoch, model and loss_function are assumed to be
# defined elsewhere — confirm.
batch_size = 64
min_batch_size = 8
max_batch_size = 512
batch_step = 5
# No previous loss yet, so the first batch always counts as an improvement.
last_loss = float("inf")

dataset = Dataset(data_path)
for epoch in range(max_epoch):
    # Rewind every epoch; without this only the first epoch sees any data.
    dataset.reset()
    while dataset.current_index < dataset.data_len:
        x, y = dataset.get_next(batch_size)
        y_pred = model(x)
        loss = loss_function(y, y_pred)
        # TODO: add loss.backward() / optimizer.step() here for real training.

        loss_value = loss.item()
        # Per-batch loss is noisy; consider comparing a running average instead.
        if loss_value < last_loss:
            batch_size = min(batch_size + batch_step, max_batch_size)
        else:
            batch_size = max(batch_size - batch_step, min_batch_size)
        last_loss = loss_value


Alternatively, you could simply redefine the DataLoader at each iteration, but I don't know whether that would hurt efficiency.

1 Like