VRAM explosion with Custom Linear

Hi everyone,
I’m struggling with very high memory consumption from a custom Linear module, which you can find here:


If I switch w_r, w_i, w_j to a single w and don’t perform any concatenation (so it is effectively just an nn.Linear), the consumption is normal (equal to an nn.Linear, 6 GB out of 12 GB). But when I use torch.cat, I get an OOM (12 GB+). Is it normal that a single torch.cat operation blows up the allocated memory?
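
To make the question concrete, here is a minimal sketch of the pattern I mean (not my actual code; the shapes, the fourth component w_k and the block layout are just illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch only: several small weight matrices are concatenated
# into one large block matrix before the matmul.
class ConcatLinear(nn.Module):
    def __init__(self, in_features=128, out_features=128):
        super().__init__()
        self.w_r = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.w_i = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.w_j = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.w_k = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x):
        # Every forward pass builds a fresh (4*out) x (4*in) matrix,
        # here 512 x 512, on top of the four 128 x 128 parameters.
        rows = [
            torch.cat([self.w_r, -self.w_i, -self.w_j, -self.w_k], dim=1),
            torch.cat([self.w_i,  self.w_r, -self.w_k,  self.w_j], dim=1),
            torch.cat([self.w_j,  self.w_k,  self.w_r, -self.w_i], dim=1),
            torch.cat([self.w_k, -self.w_j,  self.w_i,  self.w_r], dim=1),
        ]
        w = torch.cat(rows, dim=0)
        return F.linear(x, w)
```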

Thank you!

Hi,

torch.cat has to create a new tensor the size of all the concatenated tensors combined, so it will use double the memory.
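
A quick way to see the extra allocation (just a sketch; needs a CUDA device):

```python
import torch

# torch.cat copies the data of its inputs into a brand-new tensor, while
# the inputs themselves stay alive, so the concatenated data exists twice.
parts = [torch.randn(128, 128, device="cuda") for _ in range(4)]
before = torch.cuda.memory_allocated()

full = torch.cat(parts, dim=0)                 # new 512 x 128 tensor

after = torch.cuda.memory_allocated()
print(after - before)                          # ~4 * 128 * 128 * 4 bytes extra
print(full.data_ptr() == parts[0].data_ptr())  # False: no storage is shared
```

And if the cat happens inside forward during training, the concatenated tensor is typically saved for the backward pass as well, so the extra copy sticks around for every such layer at once.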

Hi,

So you’re saying that in memory I’ll have the 4 (128x128) matrices and one (512x512) matrix, right?

OK, I need to find something to alleviate this.

Yes, the memory will contain both the small matrices and the concatenated one.
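
One possible direction to alleviate it (just a sketch, assuming the large matrix is only built to feed a matmul): compute the output block by block, so the full concatenated weight never has to be materialised. The helper below and its block layout are illustrative, not code from this thread.

```python
import torch
import torch.nn.functional as F

def block_linear(x, blocks):
    """Apply a block matrix [[W_00, W_01, ...], ...] to x without ever
    concatenating the blocks into one large weight tensor.

    blocks[p][q] has shape (out_f, in_f); x has shape (batch, n_cols * in_f).
    """
    xs = x.chunk(len(blocks[0]), dim=-1)   # split x to match the column blocks
    out_blocks = []
    for row in blocks:
        # Each output block is a sum of small matmuls instead of one big one.
        out_blocks.append(sum(F.linear(xq, wq) for xq, wq in zip(xs, row)))
    return torch.cat(out_blocks, dim=-1)   # only the (much smaller) output is concatenated
```

Whether this actually saves memory depends on the batch size and on how many such layers the model has, so it is worth checking with torch.cuda.memory_allocated().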