Hi,
I’m using code based on the Vision Transformer from timm and trying to squeeze the most out of my available GPU memory. My runs often crash at the following line of pytorch-image-models/timm/models/vision_transformer.py (huggingface/pytorch-image-models on GitHub, commit c241081251323dfc5e8dc799d49740c48cc9096f), which is the middle line here:
```python
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
```
My goal in this post is not to debug that repo; this is a more general PyTorch question.
Here’s the problem, as far as I can tell: `attn` is a very large tensor. The softmax call tries to allocate a second tensor of the same size and assign it to `attn`, but there is a moment when both tensors reside in GPU memory simultaneously (requiring double the memory) until the old `attn` is replaced by the new one and freed.
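One way to check the autograd side of this is a minimal sketch with `torch.autograd.graph.saved_tensors_hooks`, which records the storage of every tensor autograd keeps for the backward pass. It assumes a recent PyTorch release, uses hypothetical small shapes, runs on CPU, and drops the dropout line; the helper names are made up for illustration. It only shows which tensors stay alive after the forward pass, not the allocator's peak (on a GPU, `torch.cuda.max_memory_allocated()` would be the thing to look at).

```python
import torch


def attention_scores_and_probs(q, k):
    """First two lines of the timm snippet, keeping both results so they can be inspected."""
    scores = q @ k.transpose(-2, -1)  # raw attention scores (the "old" attn)
    probs = scores.softmax(dim=-1)    # softmax output (the "new" attn)
    return scores, probs


def saved_data_ptrs(fn, *args):
    """Run fn and collect the data pointers of all tensors autograd saves for backward."""
    ptrs = []

    def pack(t):
        ptrs.append(t.data_ptr())  # record which storage autograd keeps alive
        return t                   # store the tensor unchanged

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        result = fn(*args)
    return result, ptrs


torch.manual_seed(0)
# Hypothetical shapes: (batch, heads, tokens, head_dim).
q = torch.randn(2, 4, 16, 8, requires_grad=True)
k = torch.randn(2, 4, 16, 8, requires_grad=True)

(scores, probs), ptrs = saved_data_ptrs(attention_scores_and_probs, q, k)
print("softmax output saved:", probs.data_ptr() in ptrs)
print("raw scores saved:", scores.data_ptr() in ptrs)
```

If the softmax output is saved but the raw scores are not, then nothing in the graph needs the old `attn`, and it can be freed as soon as the Python name no longer refers to it.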
Assuming I’m right so far (am I?), I’m trying to see whether this can be avoided. Specifically, I believe the problem reduces to computing `exp(attn)` (the normalization afterwards is less demanding). I know that calling `torch.exp(x, out=x)` is not possible with autograd (as noted here). As far as I understand, the issue is that the input we overwrite is the output of the previous operator, which might need it for its backward pass (as noted here, if I understand @tom correctly). However, as far as I know, the matrix product (`q @ k.transpose(-2, -1)`) should not need to keep its output.
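To make the question concrete, here is a sketch of the exponentiation step done in place with `exp_()` instead of `out=`. It assumes autograd accepts this because the matmul backward needs only `q` and `k`, and `exp_` saves only its own result. It subtracts a detached row max first, as a numerically stable softmax would; that step and the helper names are my assumptions, not timm's code. The script compares gradients against the out-of-place version and checks that the `out=` form is rejected. It only covers `exp`; the normalization afterwards still needs its own handling.

```python
import torch


def exp_scores(q, k, in_place):
    """Return exp(q @ k^T - row_max), computed either in place or out of place."""
    scores = q @ k.transpose(-2, -1)
    # Detached row max for numerical stability (softmax is shift-invariant).
    row_max = scores.detach().amax(dim=-1, keepdim=True)
    if in_place:
        # Reuse the matmul output's storage: the matmul backward needs q and k,
        # not its output, and exp_ saves only its own result.
        return scores.sub_(row_max).exp_()
    # Out of place: each op allocates a new full-size tensor.
    return torch.exp(scores - row_max)


def out_kwarg_fails(x):
    """True if torch.exp(x, out=x) raises because x requires grad."""
    try:
        torch.exp(x, out=x)
    except RuntimeError:
        return True
    return False


torch.manual_seed(0)
# Hypothetical shapes: (batch, heads, tokens, head_dim).
q = torch.randn(2, 4, 16, 8, requires_grad=True)
k = torch.randn(2, 4, 16, 8, requires_grad=True)

grads = []
for in_place in (False, True):
    q.grad = k.grad = None
    exp_scores(q, k, in_place).sum().backward()
    grads.append((q.grad.clone(), k.grad.clone()))

print("same q grad:", torch.allclose(grads[0][0], grads[1][0]))
print("same k grad:", torch.allclose(grads[0][1], grads[1][1]))
print("out= rejected:", out_kwarg_fails(q @ k.transpose(-2, -1)))
```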
Is there anything that can be done here?
Thanks,
Yiftach