I’ve got a module (mostly) out of my control that I’d like to apply to a special tensor of inputs. The input tensor has shape `(batch_size, n_points, 3 + context_dim)`, where each element of the batch, of shape `(n_points, 3 + context_dim)`, looks like
```
[x_1 | context]
[x_2 | context]
...
[x_i | context]
...
[x_n | context]
```
Here I’ve used `|` to indicate concatenation of vectors.
For reference, `context_dim` is 256 and `n_points` is 2048. Each `x_i` is a 3-dimensional vector, and `context` is the same vector for every row within a given batch element.
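For concreteness, here’s a minimal sketch of what I mean by the naive construction (the tensor names and the batch size of 8 are just placeholders I made up):

```python
import torch

batch_size, n_points, context_dim = 8, 2048, 256  # batch_size is just an example value

points = torch.randn(batch_size, n_points, 3)    # the x_i's
context = torch.randn(batch_size, context_dim)   # one context vector per batch element

# Naive construction: physically repeat the context for every point, then concatenate.
context_rep = context.unsqueeze(1).repeat(1, n_points, 1)  # (batch_size, n_points, context_dim)
inputs = torch.cat([points, context_rep], dim=-1)          # (batch_size, n_points, 3 + context_dim)
```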
My issue is that this tensor takes up a lot of memory if it’s constructed naively. I’m familiar with striding and know that the context vector can be expanded to the right shape for concatenation without any extra allocation, but if I understand correctly, the concatenation itself will still allocate the full result, duplicating the context vector `n_points` times.
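To illustrate what I mean (again just a sketch with made-up names): `expand` gives a stride-0 view at no cost, but `cat` still materializes the full result.

```python
import torch

batch_size, n_points, context_dim = 8, 2048, 256
points = torch.randn(batch_size, n_points, 3)
context = torch.randn(batch_size, context_dim)

# expand() only creates a stride-0 view along the points dimension; no copy is made here.
context_view = context.unsqueeze(1).expand(batch_size, n_points, context_dim)
assert context_view.data_ptr() == context.data_ptr()  # shares storage with `context`

# cat() has to write out a contiguous result, so the context gets duplicated n_points times anyway.
inputs = torch.cat([points, context_view], dim=-1)  # (batch_size, n_points, 3 + context_dim)
print(inputs.nelement() * inputs.element_size() / 2**20)  # ~16 MiB at float32 for batch_size=8
```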
From what I can tell, this RFC might solve my problem, since I could combine the tensors into a NestedTensor, but it seems a long way off.
How can I reduce memory consumption in this situation with current versions of PyTorch?