torch.cat() a massive bottleneck?!

I was training something that looks like a transformer, and of course this requires accumulating a tensor of previous tokens (and other per-step quantities) in order to compute attention. I do this with torch.cat() at each timestep, which I have found slows my code down by a factor of 2 and also doubles memory usage.
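
A minimal sketch of the accumulation pattern described above, assuming a (batch, time, dim) layout and random stand-in tensors for each step's output. torch.cat always allocates a new output tensor, so every call copies the whole history again and the total copying grows quadratically with the number of timesteps. The counter below just makes that explicit.

```python
import torch

def accumulate_with_cat(steps: int, batch: int = 2, dim: int = 8):
    """Grow a (batch, t, dim) history one timestep at a time with torch.cat,
    counting how many elements end up being written across all cat calls."""
    history = torch.empty(batch, 0, dim)  # start with an empty time axis
    copied = 0
    for _ in range(steps):
        new_token = torch.randn(batch, 1, dim)  # stand-in for this step's output
        history = torch.cat([history, new_token], dim=1)  # allocates a fresh tensor
        copied += history.numel()  # the old history is copied again every step
    return history, copied

history, copied = accumulate_with_cat(steps=4)
print(history.shape)  # torch.Size([2, 4, 8])
print(copied)         # 2 * 8 * (1 + 2 + 3 + 4) = 160
```

If the attention ops at each step also save their copy of the history for backward, all of those intermediate copies stay alive until backward() runs, which is one plausible contributor to the extra memory.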

It makes sense that creating new tensors is differentiable whereas in-place assignment to a preallocated tensor is not, but is this really the best we can do?! So many GPU hours are spent training transformers that I find it very hard to believe such a large bottleneck is so widely adopted without anyone looking into alternatives.

If you are using torch.cat in a loop, you should instead append all tensors to a list and create the final tensor outside of the loop with a single torch.cat or torch.stack call. Alternatively, you could preallocate the output tensor and fill it. If these in-place writes would break your training, Autograd should raise an error, and you would have to fall back to the first approach.
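
Both suggestions can be sketched as follows, assuming a hypothetical step_output(t) helper and parameter w that stand in for whatever a single step computes. The list version allocates the final tensor once; the preallocated version writes each step into its own slice. Filling the buffer works with Autograd here because nothing reads the buffer (and saves it for backward) before it is completely filled.

```python
import torch

batch, steps, dim = 2, 4, 8
w = torch.ones(dim, requires_grad=True)  # hypothetical parameter

def step_output(t: int) -> torch.Tensor:
    """Hypothetical stand-in for one step's computation, shape (batch, dim)."""
    return torch.full((batch, dim), float(t)) * w

# Option 1: collect per-step outputs in a list, build the tensor once after the loop.
outputs = []
for t in range(steps):
    outputs.append(step_output(t))
stacked = torch.stack(outputs, dim=1)  # single allocation, shape (batch, steps, dim)

# Option 2: preallocate the result and fill one slice per step in place.
prealloc = torch.empty(batch, steps, dim)
for t in range(steps):
    prealloc[:, t] = step_output(t)  # recorded by Autograd as an in-place copy

assert torch.equal(stacked, prealloc)
prealloc.sum().backward()  # gradients flow back to w through the in-place writes
print(w.grad)  # each entry is batch * (0 + 1 + 2 + 3) = 12
```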


This was my approach before incorporating attention, but the issue is that computing attention autoregressively requires using the outputs from all previous timesteps at the next timestep, so the accumulated tensor is needed inside the loop, not just after it.
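
A toy version of this situation shows why preallocation alone does not help here. The sketch assumes single-head dot-product attention where each step's key doubles as its query and value, with a hypothetical key projection W. Each step reads a slice of the history, and the elementwise multiplication saves that slice for backward. Writing the next timestep into the same preallocated buffer then modifies a tensor Autograd still needs, so backward() fails, while building a fresh history with torch.stack at every step works at the cost of copying it again.

```python
import torch

def autoregressive_attention(steps: int, use_buffer: bool) -> bool:
    """Toy autoregressive loop: each step attends over every previous step's key.
    Returns True if backward() succeeds, False if Autograd raises RuntimeError."""
    torch.manual_seed(0)
    batch, dim = 2, 8
    W = torch.randn(dim, dim, requires_grad=True)  # hypothetical key projection
    x = torch.randn(batch, dim)
    buffer = torch.zeros(batch, steps, dim)  # used only when use_buffer=True
    keys = []                                # used only when use_buffer=False
    for t in range(steps):
        k = x @ W  # this step's key, shape (batch, dim)
        if use_buffer:
            buffer[:, t] = k            # in-place write into the preallocated history
            past = buffer[:, : t + 1]   # a view that shares the buffer's storage
        else:
            keys.append(k)
            past = torch.stack(keys, dim=1)  # fresh tensor each step (copies history)
        scores = (past * k.unsqueeze(1)).sum(-1)  # (batch, t+1); mul saves `past`
        attn = torch.softmax(scores, dim=-1)
        x = (attn.unsqueeze(-1) * past).sum(1)    # attended output feeds the next step
    try:
        x.sum().backward()
        return True
    except RuntimeError:  # a saved view was modified by a later in-place write
        return False

print(autoregressive_attention(4, use_buffer=True))   # False
print(autoregressive_attention(4, use_buffer=False))  # True
```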

I could perhaps live with the torch.cat bottleneck, but I still can't help but wonder: what makes in-place assignment such a no-no? Not just tensor[:, 1] = something, but also tensor[:, 1] += something. It seems like as long as the operation is tracked, it could be differentiable, and there would be no need to make duplicate tensors.
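
For what it's worth, Autograd does track in-place operations, including slice assignment and +=. A minimal example, assuming a small parameter w written into a zero-initialized buffer: the gradient flows back through both in-place updates because neither overwritten value is needed for backward (the gradients of addition and copying do not depend on the values being replaced).

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
buf = torch.zeros(2, 3)  # plain buffer, does not require grad itself

buf[1] += w * 2  # in-place add into a slice; Autograd records it
buf[1] += w      # a second in-place update to the same slice

buf.sum().backward()  # buf[1] == 3 * w, so d(sum)/dw is 3 for every entry
print(w.grad)         # tensor([3., 3., 3.])
```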

In-place operations are not allowed when the original value they overwrite is needed to calculate the gradients.
E.g. in a vanilla CNN you could try to use the in-place ReLU (nn.ReLU(inplace=True)) and keep it as long as PyTorch doesn't raise an error.
If it does, the input needs to stay unchanged for Autograd.
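
A small sketch of both outcomes, assuming arbitrary layer sizes. The in-place ReLU after a Linear layer is fine because ReLU's gradient can be computed from its output and the Linear layer does not save its output. torch.sigmoid, on the other hand, saves its output for backward, so scaling that output in place makes backward() raise a RuntimeError. PyTorch detects this by comparing a version counter recorded on each saved tensor.

```python
import torch
import torch.nn as nn

# Case 1: in-place ReLU after Linear. Nothing saved the Linear output for
# backward, and ReLU's own gradient only needs its result, so this works.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True), nn.Linear(4, 1))
model(torch.randn(3, 4)).sum().backward()

# Case 2: sigmoid saves its output y (its gradient is y * (1 - y)),
# so modifying y in place invalidates a value Autograd still needs.
a = torch.randn(3, requires_grad=True)
y = torch.sigmoid(a)
y.mul_(2)  # in-place change to a saved tensor
try:
    y.sum().backward()
    sigmoid_ok = True
except RuntimeError as err:
    sigmoid_ok = False
    print(err)  # explains which saved tensor was modified in place

print(model[0].weight.grad is not None, sigmoid_ok)  # True False
```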