Reduce memory footprint when processing same input for multiple linear layers

Consider the following:

layer = nn.Conv1d(64, 64, kernel_size=1, groups=4)

This layer represents 4 linear layers that run in parallel, where each of them processes vectors of 16 dimensions and returns vectors of the same dimensionality.
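To make the "4 parallel linear layers" reading concrete, here is a small sketch that checks it numerically. It assumes a PyTorch version that accepts unbatched (channels, length) input to Conv1d, and it treats the length dimension as a batch of N independent vectors (valid because kernel_size=1). The four nn.Linear(16, 16) layers are built only for this comparison; they are not part of the original setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Grouped 1x1 conv: 64 input channels split into 4 groups of 16.
layer = nn.Conv1d(64, 64, kernel_size=1, groups=4)

N = 10
inp = torch.randn(64, N)  # unbatched Conv1d input: (channels, length)
out = layer(inp)

# Rebuild the same computation as 4 independent nn.Linear(16, 16) layers.
# Group g owns output channels 16*g:16*(g+1) and reads the matching
# 16 input channels; its weight block is layer.weight[rows, :, 0].
outs = []
for g in range(4):
    rows = slice(16 * g, 16 * (g + 1))
    lin = nn.Linear(16, 16)
    with torch.no_grad():
        lin.weight.copy_(layer.weight[rows, :, 0])  # (16, 16)
        lin.bias.copy_(layer.bias[rows])
    # nn.Linear expects (batch, features), so treat the N positions as a batch.
    outs.append(lin(inp[rows].t()).t())

out_linear = torch.cat(outs, dim=0)  # (64, N)
print(torch.allclose(out, out_linear, atol=1e-6))
```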

Now, suppose that we want each of these layers to process the same input x = torch.randn(N, 16).
To do that, we need to transpose and repeat the input as follows:

x = x.t().repeat(4, 1)

The input will now have shape (64, N), and we could run layer(x).
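A quick sketch of what the transpose-and-repeat step produces, assuming the same x = torch.randn(N, 16) as above (N = 5 is an arbitrary choice here): every 16-row block of the result is an identical copy of x.t(), which is exactly why each group ends up seeing the same input.

```python
import torch

N = 5
x = torch.randn(N, 16)

x_rep = x.t().repeat(4, 1)  # (16, N) tiled 4 times along dim 0 -> (64, N)
print(x_rep.shape)  # torch.Size([64, 5])

# Each 16-row block is a full copy of x.t(), so all 4 groups of the
# grouped conv receive the same input.
for g in range(4):
    assert torch.equal(x_rep[16 * g:16 * (g + 1)], x.t())
```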

For very large N or vector dimensions, this repeat operation allocates a huge tensor, even though we know it is just a single tensor repeated multiple times.
My question is: Can we avoid this memory footprint?
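The memory cost can be measured directly. This sketch assumes PyTorch 2.0 or newer (for Tensor.untyped_storage()) and contrasts repeat, which materializes a new buffer, with expand, which only creates a view whose new leading dimension has stride 0 and therefore shares x's storage.

```python
import torch

N = 1000
x = torch.randn(N, 16)
xt = x.t()

rep = xt.repeat(4, 1)                   # materializes a new (64, N) tensor
exp = xt.unsqueeze(0).expand(4, 16, N)  # (4, 16, N) view, no new memory

def nbytes(t):
    # Bytes held by the tensor's underlying storage.
    return t.untyped_storage().nbytes()

# repeat needs 4x the storage of x; expand reuses x's storage.
print(nbytes(x), nbytes(rep), nbytes(exp))
print(exp.data_ptr() == x.data_ptr())
```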

I currently don’t see a way to use a combination of expand and view to avoid this issue.
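The reason expand + view does not help can be sketched as follows: flattening the expanded (4, 16, N) view into (64, N) would require merging a stride-0 dimension with a stride-1 dimension, which no single stride can express. view() therefore raises a RuntimeError, and reshape() silently falls back to a copy, i.e. the same allocation as repeat().

```python
import torch

N = 8
x = torch.randn(N, 16)

# Zero-copy view: the new leading dim has stride 0.
exp = x.t().unsqueeze(0).expand(4, 16, N)

# Merging dims (4, 16) into 64 is not expressible with strides here,
# so view() refuses ...
try:
    exp.view(64, N)
    view_ok = True
except RuntimeError as err:
    view_ok = False
    print("view failed:", err)

# ... while reshape() copies the data into a new buffer.
flat = exp.reshape(64, N)
print(view_ok, flat.data_ptr() == x.data_ptr())
```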

Since you are just repeating the input, a plain nn.Conv1d using in_channels=16 (with the same weights) would yield the same results as your grouped conv:

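import torch
import torch.nn as nn

layer = nn.Conv1d(64, 64, kernel_size=1, groups=4, bias=False)

# Grouped conv on the repeated input
x = torch.randn(1, 16, 1)
x_ = x.repeat(1, 4, 1)
out_ = layer(x_)

# Plain conv with in_channels=16; both weights have shape [64, 16, 1],
# so the grouped layer's weight can be copied over directly
conv = nn.Conv1d(16, 64, 1, bias=False)
with torch.no_grad():
    conv.weight.copy_(layer.weight)

out = conv(x)

print((out_ - out).abs().max())
# tensor(1.1921e-07, grad_fn=<MaxBackward1>)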
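Why this works can also be seen without any conv at all: with kernel_size=1 and a repeated input, group g computes W_g @ x, where W_g = weight[16g:16(g+1), :, 0], and stacking the four blocks simply gives the (64, 16) matrix weight[:, :, 0]. The sketch below, which assumes unbatched (channels, length) input is accepted, checks that a single matrix multiply on the non-repeated input matches the grouped conv.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Conv1d(64, 64, kernel_size=1, groups=4, bias=False)

N = 10
x = torch.randn(16, N)  # one 16-dim vector per column

out_ref = layer(x.repeat(4, 1))  # grouped conv on the repeated input, (64, N)

# The 4 per-group (16, 16) blocks stacked along dim 0 form a (64, 16) matrix.
W = layer.weight[:, :, 0]  # (64, 16)
out_mm = W @ x             # (64, N), no repeat needed
print(torch.allclose(out_ref, out_mm, atol=1e-6))
```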


Apparently, we could also just set the groups attribute of layer to 1 (layer.groups = 1) and then run layer(x) on the non-repeated input.

Would you see any issue in doing that?

Yes, I don’t believe this would work directly, since you would need to increase the weight matrix from [64, 16, 1] to [64, 64, 1]. Of course, you could then try to repeat the 16 filters to create a kernel with the same weights, but this approach sounds quite wasteful, since you would be:

  • repeating the filter kernels
  • repeating the input tensor
  • increasing the computational workload

without any need, since nn.Conv1d(16, 64, 1, bias=False) will just work without any repetition.
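Since nn.Conv1d(64, 64, 1, groups=4) and nn.Conv1d(16, 64, 1) have parameters of identical shapes (weight [64, 16, 1], bias [64]), one possible way to move an existing grouped layer to the plain one is to load its state_dict directly. This is a sketch under that observation, not something stated in the thread; it assumes unbatched (channels, length) input is accepted.

```python
import torch
import torch.nn as nn

grouped = nn.Conv1d(64, 64, kernel_size=1, groups=4)  # weight (64, 16, 1), bias (64,)
plain = nn.Conv1d(16, 64, kernel_size=1)              # weight (64, 16, 1), bias (64,)

# Parameter names and shapes match, so a strict load succeeds.
plain.load_state_dict(grouped.state_dict())

N = 10
x = torch.randn(16, N)
out_grouped = grouped(x.repeat(4, 1))  # needs the 4x repeated input
out_plain = plain(x)                   # works on x as-is
print(torch.allclose(out_grouped, out_plain, atol=1e-6))
```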

My previous reply was not clear enough.

Take a look at the following snippet:

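import torch

# Grouped layer: weight has shape [64, 16, 1] (64 / groups = 16 input channels per filter)
layer = torch.nn.Conv1d(64, 64, kernel_size=1, groups=4, bias=True)
print('groups: ', layer.groups, 'weight: ', layer.weight.shape, 'bias: ', layer.bias.shape)

N = 100
x = torch.randn(16, N)          # unbatched input: (channels, length)
x_repeated = x.repeat(4, 1)     # (64, N)

out_repeated = layer(x_repeated)

# Switch to an ungrouped conv without touching the parameters:
# the [64, 16, 1] weight now consumes the 16 input channels directly
layer.groups = 1
print('groups: ', layer.groups, 'weight: ', layer.weight.shape, 'bias: ', layer.bias.shape)
out = layer(x)

print(torch.allclose(out, out_repeated, atol=1e-6))
print((out - out_repeated).abs().max())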
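A variant of the same idea that leaves the module unchanged is to call the functional API with the grouped layer's parameters. This is a sketch, assuming the layer keeps the default stride, padding, and dilation (as here): torch.nn.functional.conv1d defaults to groups=1, so the [64, 16, 1] weight is applied to the 16-channel input directly, while layer.groups stays 4 for any other code that uses the module.

```python
import torch
import torch.nn.functional as F

layer = torch.nn.Conv1d(64, 64, kernel_size=1, groups=4, bias=True)

N = 100
x = torch.randn(16, N)
out_repeated = layer(x.repeat(4, 1))

# Same effect as layer.groups = 1, but without mutating the module:
# F.conv1d uses groups=1 by default.
out = F.conv1d(x, layer.weight, layer.bias)

print(layer.groups, torch.allclose(out, out_repeated, atol=1e-6))
```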