Model design results in memory explosion

Adam tends to have a higher memory footprint than SGD, since it keeps additional running-average state for every parameter: Optimizers memory usage
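
For instance, here is a rough sketch of how you could compare the extra per-parameter state each optimizer keeps (the toy model and sizes below are placeholders, not from the linked thread):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)  # toy model, just for illustration

def optimizer_state_bytes(optimizer):
    # Sum the sizes of all tensors the optimizer stores per parameter
    # (Adam keeps exp_avg and exp_avg_sq; plain SGD keeps essentially nothing).
    total = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total

for opt_cls in (torch.optim.SGD, torch.optim.Adam):
    opt = opt_cls(model.parameters(), lr=1e-3)
    model(torch.randn(8, 1024)).sum().backward()  # populate gradients
    opt.step()                                    # populate optimizer state
    print(opt_cls.__name__, optimizer_state_bytes(opt), "bytes of optimizer state")
```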

You could try one or more of the following:

  1. Try reducing the number of parameters,
  2. Set your model, training data, and labels/targets to half precision, e.g. model.to(dtype=torch.float16) (a minimal sketch follows this list),
  3. Use a more memory-efficient attention: GitHub - lucidrains/memory-efficient-attention-pytorch: Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"
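
A minimal sketch of option 2, assuming a CUDA device and a toy classification setup (the model, inputs, and targets below are placeholders for your own):

```python
import torch
import torch.nn as nn

device = "cuda"  # assumes a CUDA GPU; float16 support on CPU is limited

# Cast the model weights to float16.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
model = model.to(device=device, dtype=torch.float16)

# Cast the inputs to float16 as well. Integer class labels stay int64;
# float-valued targets (e.g. for regression) would be cast to float16 too.
inputs = torch.randn(32, 256, device=device, dtype=torch.float16)
targets = torch.randint(0, 10, (32,), device=device)

logits = model(inputs)
loss = nn.functional.cross_entropy(logits, targets)
loss.backward()
print(loss.item())
```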

Update to 3: I see PyTorch has implemented memory-efficient attention in PyTorch 2.0. See here: torch.nn.functional.scaled_dot_product_attention — PyTorch master documentation
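
A minimal usage sketch of that built-in, assuming PyTorch >= 2.0 and a CUDA device (the shapes are arbitrary, and the backend-selection context manager reflects the 2.0-era API):

```python
import torch
import torch.nn.functional as F

# (batch, heads, sequence_length, head_dim) -- arbitrary example shapes
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# PyTorch dispatches to a fused kernel (FlashAttention or the memory-efficient
# implementation) whenever the inputs allow it.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# To restrict dispatch to the memory-efficient kernel (API as of PyTorch 2.0;
# later releases expose this under torch.nn.attention instead):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False,
                                    enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```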