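Adam tends to use more memory than SGD, because it keeps two extra state tensors (running averages of the gradient and of the squared gradient) for every parameter. See: Optimizers memory usage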
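To see where the extra memory comes from, here is a minimal sketch that counts the elements held in each optimizer's state after one step. The toy `nn.Linear` model and the helper names `optimizer_state_numel` and `state_after_one_step` are made up for illustration; plain SGD is run without momentum, which is an assumption (with momentum it keeps one buffer per parameter).

```python
import torch
from torch import nn


def optimizer_state_numel(optimizer):
    """Count elements held in the optimizer's per-parameter state tensors."""
    total = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel()
    return total


def state_after_one_step(optimizer_cls, **kwargs):
    """Run one training step on a toy model and report parameter/state sizes."""
    torch.manual_seed(0)
    model = nn.Linear(128, 64)  # hypothetical toy model
    optimizer = optimizer_cls(model.parameters(), **kwargs)
    loss = model(torch.randn(8, 128)).sum()
    loss.backward()
    optimizer.step()  # optimizer state is allocated lazily on the first step
    n_params = sum(p.numel() for p in model.parameters())
    return n_params, optimizer_state_numel(optimizer)


n_params, sgd_state = state_after_one_step(torch.optim.SGD, lr=0.01)
_, adam_state = state_after_one_step(torch.optim.Adam, lr=0.001)

print(f"parameters:        {n_params}")
print(f"SGD state elems:   {sgd_state}")
print(f"Adam state elems:  {adam_state}")
```

On top of the weights and gradients, Adam's state should come out at roughly twice the parameter count, while SGD without momentum keeps no per-parameter tensors. The exact number can differ slightly between PyTorch versions, since some store the step counter as a tensor.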
You could try any of the following:
- try reducing the number of parameters,
- set your model and training data to half precision (floating-point targets too; integer class labels should keep their integer dtype), i.e.
model.to(dtype=torch.float16)
- use a more memory-efficient attention, e.g. lucidrains/memory-efficient-attention-pytorch on GitHub, an implementation of the memory-efficient multi-head attention proposed in the paper "Self-attention Does Not Need O(n²) Memory"
Update to the third option: PyTorch 2.0 ships its own memory-efficient attention. See torch.nn.functional.scaled_dot_product_attention in the PyTorch documentation.
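As a rough sketch of the last two options combined, the snippet below calls `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0 or later) and checks it against a naive attention that materializes the full score matrix. The tensor sizes and the `naive_attention` helper are hypothetical, and falling back to float32 on CPU is my assumption, since half precision is mainly useful on a GPU. Which backend (flash, memory-efficient, or plain math) actually runs depends on your hardware, dtype and PyTorch version.

```python
import math

import torch
import torch.nn.functional as F

# Assumption: use float16 only when a GPU is available, float32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 16, 32  # hypothetical sizes
q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)


def naive_attention(q, k, v):
    """Reference attention that builds the full (seq_len x seq_len) score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


# Fused attention: PyTorch picks an available backend for these inputs.
fused = F.scaled_dot_product_attention(q, k, v)

# Compare against the naive version computed in float32.
reference = naive_attention(q.float(), k.float(), v.float())
max_diff = (fused.float() - reference).abs().max().item()
print(f"device={device}, dtype={dtype}, max abs difference={max_diff:.2e}")
```

If you write your own attention layer, swapping the manual softmax(QK^T)V computation for this call is usually a small change; `nn.MultiheadAttention` and `nn.TransformerEncoderLayer` may already use it internally in recent versions, so check before rewriting anything.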