Implementing softmax or exp in-place

I’m using code based on the Vision Transformer from timm and trying to squeeze the most out of my available GPU memory. Often, my attempts crash at the following line:

this line in pytorch-image-models/timm/models/ (commit c241081251323dfc5e8dc799d49740c48cc9096f in huggingface/pytorch-image-models on GitHub), which is the middle line here:

            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

My goal in this post is not to debug that repo, and is a more general pytorch question.

Here’s the problem, as far as I can tell: attn is a very large tensor. The softmax call creates a second tensor of the same size and assigns it to attn, so there is a moment when both tensors simultaneously reside in GPU memory (requiring double the memory) until the old attn is freed.
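A quick way to see the extra allocation (a small CPU illustration, not the timm code; the same reasoning applies to CUDA’s caching allocator) is to compare data pointers: while softmax runs, the old attn is still referenced, so its storage cannot be reused for the result.

```python
import torch

# Small stand-in for the large attention tensor (CPU, for illustration)
attn = torch.randn(2, 4, 16, 16)
ptr_before = attn.data_ptr()

# Out-of-place softmax: a second, same-sized tensor is allocated while the
# old one is still alive, so both buffers coexist until the rebind below.
attn = attn.softmax(dim=-1)

assert attn.data_ptr() != ptr_before  # the result lives in fresh storage
```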
Assuming I’m right so far (am I?), I’m trying to see whether this can be avoided. Specifically, I believe the problem reduces to computing exp(attn) in place (the later normalization is less demanding). I know that calling torch.exp(x, out=x) is not possible with autograd (as noted here). As far as I understand, the problem is that the input we discard is the output of the previous operator, which might be needed for backward (as noted here, if I understand @tom correctly); but as far as I know, the dot product should not need to keep its output.
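For what it’s worth, here is a small sanity check of that last point (a sketch, not the timm code). Unlike the out= variants, the in-place method Tensor.exp_() is recorded by autograd, and backward succeeds here because matmul’s backward needs q and k but not its own output, so overwriting attn does not invalidate any saved tensor:

```python
import torch

q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(4, 8, requires_grad=True)

attn = q @ k.transpose(-2, -1)  # matmul saves q and k for backward, not attn
attn.exp_()                     # in-place exp, tracked by autograd
attn.sum().backward()           # works: no saved tensor was overwritten
```

If an op whose backward *does* need its output precedes the in-place write, backward raises a version-counter error instead.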

Is there anything that can be done here?


This is an interesting idea for saving memory. Keep in mind, though, that the tensor will typically be deallocated only after it goes out of scope, so we are indeed talking about the peak while the out-of-place softmax is computed (my inplace blogpost has some more detail).

A slow way to do/try the in-place softmax is to compute the softmax row by row in a for loop out of place and then copy each result back into the larger tensor. You can then wrap this in-place softmax in an autograd function and implement the backward yourself (if you don’t want to derive the formula, you could use the PyTorch code or some other reference).
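A minimal sketch of that suggestion (hypothetical names, chunked rather than strictly row by row; the backward uses the standard softmax gradient dx = y * (dy - sum(dy * y))):

```python
import torch

class InplaceSoftmax(torch.autograd.Function):
    """Sketch: softmax over the last dim, written back into the input
    in chunks so only a chunk-sized temporary is ever allocated."""

    CHUNK = 1024  # rows per chunk (illustrative value)

    @staticmethod
    def forward(ctx, x):
        rows = x.view(-1, x.shape[-1])  # flatten leading dims into rows
        for start in range(0, rows.shape[0], InplaceSoftmax.CHUNK):
            chunk = rows[start:start + InplaceSoftmax.CHUNK]
            # small out-of-place softmax, copied back into x's storage
            chunk.copy_(torch.softmax(chunk, dim=-1))
        ctx.mark_dirty(x)         # tell autograd we modified x in place
        ctx.save_for_backward(x)  # the output y is all backward needs
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # softmax gradient: dx = y * (dy - sum(dy * y, dim=-1))
        return y * (grad_output - (grad_output * y).sum(dim=-1, keepdim=True))
```

Because the forward mutates its input, it cannot be applied directly to a leaf that requires grad; apply it to a non-leaf instead, e.g. `attn = InplaceSoftmax.apply(q @ k.transpose(-2, -1))`.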

Best regards


That’s a very interesting workaround. Thanks @tom!