Concatenating an expanded view of a tensor

I’ve got a module (mostly) out of my control that I’d like to apply to a special tensor of inputs. The input tensor is of shape (batch_size, n_points, 3 + context_dim), where each element of the batch of shape (n_points, 3 + context_dim) looks like

[x_1 | context],
[x_2 | context],
...
[x_n | context]

Here I’ve used | to indicate concatenation of vectors.

For reference, context_dim is 256 and n_points is 2048. Each x_i is a 3-dimensional vector, and context is the same vector for every row within a given element of the batch.

My issue is that this tensor takes up a lot of memory if it’s constructed naively.

I’m familiar with striding and know that it’s possible to expand the context vector to the right shape for concatenation without any extra allocation, but if I understand correctly, the concatenation operation will still require a full allocation, duplicating the context vector n_points times.
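To illustrate the problem (a minimal sketch; the batch size of 4 is an arbitrary example): `expand` produces a stride-0 view that shares storage with the original context tensor, but `cat` has to materialize a contiguous result, copying the context n_points times.

```python
import torch

batch_size, n_points, context_dim = 4, 2048, 256

x = torch.randn(batch_size, n_points, 3)        # per-point coordinates
context = torch.randn(batch_size, context_dim)  # one context vector per batch element

# expand() creates a stride-0 view: no new memory for the repeated context
ctx_view = context.unsqueeze(1).expand(batch_size, n_points, context_dim)
assert ctx_view.data_ptr() == context.data_ptr()  # same underlying storage

# cat() must produce a contiguous tensor, so it copies everything,
# duplicating the context n_points times per batch element
full = torch.cat([x, ctx_view], dim=-1)  # shape (batch_size, n_points, 3 + context_dim)
```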

From what I can tell, this RFC might solve my problem, since I could combine the tensors into a NestedTensor, but it seems a long way off.

How can I reduce memory consumption in this situation with current versions of PyTorch?

In a similar situation, since my module started with a fully-connected layer, I just used the distributivity of matmul: linear([x_i | context]; w0) + bias = linear(x_i; w1) + linear(context; w2) + bias, where w1 and w2 are the corresponding column blocks of w0. You can do the summation in place, and the output of linear(context) is a small, non-expanded tensor.
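The weight-splitting trick above can be sketched as follows (the dimensions and the `out_dim` of 128 are illustrative assumptions): the weight matrix of the first `nn.Linear` is sliced into a coordinate block and a context block, the context term is computed once per batch element, and the broadcasted in-place add avoids ever materializing the concatenated input.

```python
import torch
import torch.nn.functional as F
from torch import nn

batch_size, n_points, context_dim, out_dim = 4, 2048, 256, 128

x = torch.randn(batch_size, n_points, 3)
context = torch.randn(batch_size, context_dim)

# the original first layer, which would act on [x_i | context]
fc = nn.Linear(3 + context_dim, out_dim)

# split the weight matrix into its coordinate and context column blocks
w1 = fc.weight[:, :3]  # (out_dim, 3)
w2 = fc.weight[:, 3:]  # (out_dim, context_dim)

# per-point term: (batch_size, n_points, out_dim)
out = F.linear(x, w1, fc.bias)
# context term computed once per batch element, then broadcast over points
out += F.linear(context, w2).unsqueeze(1)

# check against the naive concatenated version
expanded = torch.cat(
    [x, context.unsqueeze(1).expand(-1, n_points, -1)], dim=-1
)
reference = fc(expanded)
```

`out` matches `reference` up to floating-point error, but the only (n_points)-sized allocations are for the outputs, not for n_points copies of the 256-dimensional context.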

If your scenario is different and you can’t process the two vector parts independently, I don’t think you can do anything about memory, since you have to interleave data from the two sources to pass to your module.
