Do a batch in multiple iterations

Yes, you could delay the optimizer.step() call using the different approaches described in this post.
Note that these approaches might yield the desired gradients, but the running stats in batch norm layers will still be computed from the smaller batches, so you might need to adjust their momentum.
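For reference, a common gradient accumulation pattern looks roughly like the sketch below. The model, loader, and `accumulation_steps` value are placeholders for illustration; the loss is scaled by the number of accumulation steps so the accumulated gradients approximately match those of the full batch.

```python
import torch
import torch.nn as nn

# Placeholder model, loss, optimizer, and data for illustration
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 10),
                                   torch.randint(0, 2, (64,))),
    batch_size=4)

accumulation_steps = 4  # number of micro-batches per optimizer step

optimizer.zero_grad()
for i, (data, target) in enumerate(loader):
    output = model(data)
    # scale the loss so the accumulated gradient matches the large batch
    loss = criterion(output, target) / accumulation_steps
    loss.backward()  # gradients are accumulated in the .grad attributes
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```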

Alternatively, you could use torch.utils.checkpoint to trade compute for memory and keep the large batch size.
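A minimal checkpointing sketch, assuming a simple two-block placeholder model: activations inside the checkpointed block are not stored during the forward pass and are recomputed during the backward pass, which saves memory at the cost of extra compute.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder blocks for illustration
block1 = nn.Sequential(nn.Linear(10, 100), nn.ReLU())
block2 = nn.Sequential(nn.Linear(100, 2))

x = torch.randn(8, 10, requires_grad=True)
# block1's intermediate activations are recomputed during backward;
# use_reentrant=False requires a recent PyTorch version
h = checkpoint(block1, x, use_reentrant=False)
out = block2(h)
out.sum().backward()
```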