Help understanding PyTorch memory model

Hi, I don’t really understand when PyTorch internally allocates memory or makes copies.

I’m using the fairseq multi-head attention code to build a transformer model that can do incremental encoding. However, I’ve noticed that encoding with incremental state is 5x slower than encoding the entire sequence at once.

My current hypothesis is that it’s because of this line

This line cats the new key row onto the end of the previously saved key matrix. When this operation is done, does PyTorch internally do a malloc, copy the old matrix over, and then copy the new row in? If so, is there a way to preallocate space for rows like this to avoid a copy when cat is called?
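For what it’s worth, it’s easy to check that `torch.cat` always materializes a fresh tensor rather than extending one of its inputs in place. A minimal sketch (the shapes here are hypothetical, roughly matching fairseq’s `(seq_len, bsz, dim)` key layout):

```python
import torch

# Hypothetical shapes: previously saved keys and one new key row.
prev_k = torch.randn(10, 2, 64)   # (seq_len, bsz, dim)
new_row = torch.randn(1, 2, 64)   # one new timestep

k = torch.cat([prev_k, new_row], dim=0)

# cat returns a new tensor backed by new memory; it never aliases its inputs,
# so both old tensors are copied into the fresh allocation.
assert k.data_ptr() != prev_k.data_ptr()
assert k.shape == torch.Size([11, 2, 64])
```

So yes: each incremental step pays for one allocation plus a copy of the entire cache so far, which grows quadratically over the sequence.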

k will be a new tensor and thus use new memory for the concatenated tensors.
PyTorch uses a caching allocator internally, which should reuse the GPU memory (assuming you are using the GPU), but this operation would still copy the tensors.
You could append the tensors to a list and call torch.cat or torch.stack once at the end, if possible.
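A minimal sketch of that pattern, assuming per-timestep key rows of a made-up shape `(bsz, dim)`:

```python
import torch

# Collect one key row per timestep; list append is O(1) and copies nothing.
rows = [torch.randn(2, 64) for _ in range(5)]

# Build the full tensor once at the end: a single allocation and copy,
# instead of one reallocation + full copy per timestep.
k = torch.stack(rows, dim=0)
assert k.shape == torch.Size([5, 2, 64])
```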

Thanks @ptrblck.

If I call cat on a list of tensors repeatedly, won’t each cat call require a new malloc + copy? The list would have a new element for every input token in the sequence.

I was planning on making a much larger tensor (say, 500 columns), filling in its columns one by one, and resizing to a larger tensor when the original one was full.

Yes, calling torch.cat multiple times would create new tensors, so you could append all tensors to a list and create the final tensor once it’s ready.
Based on your description, it seems the size of the final tensor is unknown and you would need to reallocate new memory nevertheless (by calling torch.cat multiple times)?

Yeah, my plan was to preallocate a tensor with, say, 500 rows and double it in size every time I need to resize, so that I don’t end up doing quite as many copies. I was looking for the cheapest way to do that.

Unfortunately, at each input step I need to construct the cat’ed tensor (in order to compute the attention scores at that timestep), so I can’t just use a list and cat it once at the end.
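The preallocate-and-double plan above can be sketched as a small helper (all names here are hypothetical, not fairseq API). The key point is that slicing the filled prefix returns a view, so the attention computation at each timestep never triggers a copy; a full copy only happens when the buffer doubles, which is amortized O(1) per appended row:

```python
import torch

class GrowingKeyCache:
    """Hypothetical key cache: preallocated buffer that doubles when full."""

    def __init__(self, capacity, bsz, dim):
        self.buf = torch.empty(capacity, bsz, dim)
        self.len = 0

    def append(self, row):
        # row: (bsz, dim) key vector for the new timestep
        if self.len == self.buf.size(0):
            # Buffer full: allocate double capacity and copy once.
            new_buf = torch.empty(self.buf.size(0) * 2, *self.buf.shape[1:])
            new_buf[: self.len] = self.buf
            self.buf = new_buf
        self.buf[self.len] = row        # in-place write, no reallocation
        self.len += 1

    def keys(self):
        # View of the filled prefix -- no malloc, no copy; usable directly
        # in the attention score computation at every timestep.
        return self.buf[: self.len]
```

Usage: start with a small capacity, call `append` once per input token, and pass `cache.keys()` to the attention score computation each step; the view shares storage with the buffer.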