Ensure Batch Losses Have Low Entropy or Stdev in an Epoch

It is common practice in RL to normalize the discounted rewards when computing policy gradients. In the same spirit, I want the model to perform consistently across all batches, so I would like to normalize the batch losses in a similar way: subtract their mean and divide by their standard deviation before summing them. (Another approach could be to multiply the sum of the losses by their entropy; does anyone have an idea which of the two would be preferable?)

I would appreciate it if someone could explain how to implement this in PyTorch. I already know that I can call zero_grad() and optimizer.step() once per epoch rather than once per batch. Would it work to write the batch losses into a tensor and normalize them before summing? And would it work to call loss.backward() on the summed tensor just before optimizer.step() at the end of the epoch? Thank you very much.
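Taken literally, the first idea runs into a problem: once you subtract the mean and divide by the standard deviation, the standardized losses always sum to zero, so the summed objective is a constant and its gradient vanishes. (In the RL analogy, the normalized returns are treated as constant weights on the log-probabilities, not as the quantity being differentiated.) The minimal sketch below checks this numerically on made-up loss values and shows one possible alternative that is not from the post: penalizing the spread directly with mean + lam * std, where `lam` is a hypothetical weighting hyperparameter.

```python
import torch


def standardized_sum(losses: torch.Tensor) -> torch.Tensor:
    # Literal version of the idea: z-score the batch losses, then sum them
    return ((losses - losses.mean()) / losses.std()).sum()


def mean_plus_std(losses: torch.Tensor, lam: float = 1.0) -> torch.Tensor:
    # Hypothetical alternative: penalize the spread of the batch losses directly
    return losses.mean() + lam * losses.std()


# Made-up per-batch losses, for illustration only
losses = torch.tensor([0.5, 1.0, 2.0, 4.0], requires_grad=True)

z_sum = standardized_sum(losses)
z_sum.backward()
z_grad = losses.grad.clone()  # numerically ~0 for every batch

losses.grad = None
mean_plus_std(losses, lam=0.5).backward()
penalty_grad = losses.grad.clone()  # positive, and larger for larger losses

print(z_sum.item(), z_grad, penalty_grad)
```

As for the entropy variant, note that minimizing a loss multiplied by the entropy also rewards lowering the entropy, i.e. making the batch losses less even, which may work against the goal of similar performance across batches.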

I tried this approach, but this time it didn't fit in memory. I guess PyTorch keeps the computation graph of every batch alive until backward() is called, so the graph for the whole epoch would have to fit in memory at once. Does anybody have a better idea for accomplishing this? Thanks.

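    import torch

    # Collect one scalar loss per batch. Each one keeps its autograd graph
    # alive until backward() is called, which is what exhausts GPU memory.
    batch_losses = []
    for batch_features, _, _, batch_targets in train_loader:
        outputs = model(batch_features.cuda())
        batch_losses.append(custom_loss(outputs, batch_targets.cuda()))
    losses = torch.stack(batch_losses)  # shape: (num_batches,)

    # Normalize the losses into a distribution and weight the mean loss by
    # the mean of -p*log(p) (the entropy divided by the number of batches).
    p = losses / losses.sum()
    entro = -(p * torch.log(p)).mean()
    loss = losses.mean() * entro

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()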
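One possible way around the memory limit is a two-pass epoch, sketched here under stated assumptions. First compute every batch loss under torch.no_grad(), then use autograd on that small vector of losses to get the weight d(objective)/d(loss_i) for each batch, then run the batches again and call backward() on weight * batch loss one batch at a time. By the chain rule, the accumulated gradients match those of the full-graph version, but only one batch graph is in memory at any time. This assumes the batches come in the same order on both passes (no shuffling, or a fixed seed), that the forward pass is deterministic (dropout would make the passes differ, and BatchNorm running statistics get updated twice), and it costs two forward passes per epoch. The function names are hypothetical, and the batches here are (inputs, targets) pairs rather than the four-tuples in the post.

```python
import torch
import torch.nn as nn


def entropy_weighted_objective(losses: torch.Tensor) -> torch.Tensor:
    # Same objective as the post: mean loss times the mean of -p*log(p)
    p = losses / losses.sum()
    entro = -(p * torch.log(p)).mean()
    return losses.mean() * entro


def accumulate_two_pass_grads(model, batches, loss_fn, objective=entropy_weighted_objective):
    """Fill .grad of the model parameters with the gradient of objective(batch losses),
    keeping only one batch's autograd graph in memory at a time (hypothetical helper)."""
    # Pass 1: per-batch losses without building any graph.
    with torch.no_grad():
        values = torch.stack([loss_fn(model(x), y) for x, y in batches])

    # Differentiate the epoch objective with respect to each scalar batch loss.
    leaf = values.clone().requires_grad_(True)
    objective(leaf).backward()
    weights = leaf.grad  # weights[i] = d(objective) / d(loss_i), no graph attached

    # Pass 2: recompute each batch loss and backprop it scaled by its weight.
    # .grad accumulates across backward() calls, so the sum is the full gradient.
    model.zero_grad()
    for w, (x, y) in zip(weights, batches):
        (w * loss_fn(model(x), y)).backward()
    return objective(values).item()


# Usage on a toy model and random data (illustration only)
torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(5)]
epoch_objective = accumulate_two_pass_grads(model, batches, nn.MSELoss())
optimizer.step()  # one optimizer step per epoch
```

The same helper works for other epoch-level objectives over the batch losses, since only the small `values` vector is differentiated through the objective itself.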