CUDA OOM during loss backward

Hi,
I’m running out of CUDA memory during the backward pass.

I need to train a small model using the following loss, which requires a significant amount of memory: e1 and e2 both have shape (32000, 256), so the intermediates are three 32000 x 32000 float32 matrices, i.e. at least 32000 * 32000 * 4 bytes * 3 ≈ 12 GB. However, the only GPU I have access to has ~11 GB of memory. (I have multiple such GPUs, but I haven’t figured out a way to solve the problem by exploiting them.)

import torch

def loss(e1, e2):
    # Frobenius norm of (e1 @ e1.T - e2 @ e2.T)
    ee1 = e1 @ e1.T              # (32000, 32000)
    ee2 = e2 @ e2.T              # (32000, 32000)
    ee1 -= ee2                   # in-place difference
    ee1 = torch.pow(ee1, 2)      # elementwise square
    return torch.sqrt(torch.sum(ee1))
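
For reference, a quick sanity check of the memory estimate plus a minimal repro (random tensors with requires_grad stand in for my real model outputs; I’m assuming float32, i.e. 4 bytes per element):

n, d = 32000, 256
per_matrix = n * n * 4           # one float32 (32000, 32000) matrix, ~4.1 GB
print(per_matrix * 3 / 1e9)      # three of them at once, ~12.3 GB

e1 = torch.randn(n, d, device="cuda", requires_grad=True)
e2 = torch.randn(n, d, device="cuda", requires_grad=True)
l = loss(e1, e2)
l.backward()                     # the CUDA OOM hits during this backward call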

I tried various ways of computing the loss, or parts of it, in CPU memory, but the backward pass still causes a CUDA OOM. It seems to still need ~12 GB during backward, since the computation graph is the same no matter how I divide and conquer the loss.
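
For concreteness, here are sketches of the kind of variants I tried (not my exact code; the chunk size and the CPU/GPU split are arbitrary):

# Variant 1: push the big intermediates into CPU memory.
def loss_cpu(e1, e2):
    ee1 = (e1 @ e1.T).cpu()
    ee2 = (e2 @ e2.T).cpu()
    diff = ee1 - ee2
    return torch.sqrt(torch.sum(diff.pow(2)))

# Variant 2: divide and conquer over row blocks, accumulating the
# squared entries block by block and taking the sqrt at the end.
def loss_chunked(e1, e2, chunk=4000):
    total = torch.zeros((), device=e1.device)
    for start in range(0, e1.shape[0], chunk):
        d = e1[start:start+chunk] @ e1.T - e2[start:start+chunk] @ e2.T
        total = total + (d * d).sum()
    return torch.sqrt(total)
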
So I’m wondering: is training a model with this loss even feasible on my current devices?

Thanks