I’ve got a module (mostly) out of my control that I’d like to apply to a special tensor of inputs. The input tensor has shape (batch_size, n_points, 3 + context_dim), and each element of the batch, of shape (n_points, 3 + context_dim), looks like
[x_1 | context], [x_2 | context], ... [x_i | context] ... [x_n | context]
Here I’ve used | to indicate concatenation of vectors. context_dim is 256 and n_points is 2048. Each x_i is a 3-dimensional vector, and context is the same vector for every point within an element of the batch.
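For concreteness, here is a minimal sketch of the naive construction. It assumes x is stored as a (batch_size, n_points, 3) tensor and context as a (batch_size, context_dim) tensor. The helper name build_input_naive is hypothetical, and tiny sizes are used so the result is easy to inspect:

```python
import torch

def build_input_naive(x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x is (B, N, 3), context is (B, C); returns (B, N, 3 + C)."""
    b, n, _ = x.shape
    # Repeat the context along the point dimension as a view (no copy yet)
    ctx = context.unsqueeze(1).expand(b, n, context.shape[-1])
    # Row i of batch element k becomes [x_i | context_k]; cat materializes everything
    return torch.cat([x, ctx], dim=-1)

x = torch.randn(2, 5, 3)
context = torch.randn(2, 4)
inp = build_input_naive(x, context)
print(inp.shape)  # torch.Size([2, 5, 7])
```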
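My issue is that this tensor takes up a lot of memory if it’s constructed naively: the context vector is copied once per point, so roughly 99% of the tensor (256 of every 259 values) is duplicated context.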
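A rough estimate of the cost, assuming float32 (4 bytes per element) and a hypothetical batch_size of 64. The function names are illustrative only:

```python
def naive_bytes(batch_size, n_points=2048, context_dim=256, point_dim=3, elem_bytes=4):
    # Every point carries its own copy of the context vector
    return batch_size * n_points * (point_dim + context_dim) * elem_bytes

def compact_bytes(batch_size, n_points=2048, context_dim=256, point_dim=3, elem_bytes=4):
    # Points stored once, context stored once per batch element
    return batch_size * (n_points * point_dim + context_dim) * elem_bytes

naive = naive_bytes(64)
compact = compact_bytes(64)
print(naive / 2**20)              # 129.5 (MiB)
print(compact / 2**20)            # 1.5625 (MiB)
print(round(naive / compact, 1))  # 82.9
```

At these sizes the naive layout needs roughly 83 times the memory of the unduplicated data.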
I’m familiar with striding and know that it’s possible to expand the context vector to the right shape for concatenation without any extra allocation. If I understand correctly, though, the concatenation itself requires a full allocation of this memory, duplicating the context vector for every point.
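A small check of this, assuming PyTorch 2.x (for Tensor.untyped_storage()) and an illustrative batch_size of 4. The expanded context is a stride-0 view over the original storage, while torch.cat returns a freshly allocated tensor of the full size:

```python
import torch

b, n, c = 4, 2048, 256
x = torch.randn(b, n, 3)
context = torch.randn(b, c)

# expand: stride 0 along the point dimension, same underlying memory
ctx = context.unsqueeze(1).expand(b, n, c)
print(ctx.stride())                          # (256, 0, 1)
print(ctx.data_ptr() == context.data_ptr())  # True
print(ctx.untyped_storage().nbytes())        # 4096 (just the original b x c floats)

# cat: a new contiguous (b, n, 3 + c) buffer, context duplicated n times
inp = torch.cat([x, ctx], dim=-1)
print(inp.untyped_storage().nbytes())        # 8486912
```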
From what I can tell, this RFC might solve my problem, since I could combine the tensors into a NestedTensor, but it seems a long way off.
How can I reduce memory consumption in this situation with current versions of PyTorch?