torch.cat() blows up the memory required

Hello!

When passing the result of a torch.cat() into a Linear layer, the memory required to store the output is much greater than expected. For illustration purposes, I have a simplified example.

Here I have a matrix whose values I want to fill iteratively by feeding rows of x into a linear layer. I do not use torch.cat() here. The resulting matrix takes about 1 GB of memory.

import torch
import torch.nn as nn
from torch.autograd import Variable

ff = nn.Linear(100, 1)
x = Variable(torch.randn((60, 200, 100)))
matrix = Variable(torch.zeros(60, 200, 800))

for i in range(matrix.size()[0]):
    for j in range(matrix.size()[1]):
        input = x[:, i, :]
        matrix[:, i, j] = ff(input).squeeze()

In this next code block, I am doing the same task as above. The difference is that before feeding the row of x into the linear layer, I first pass it through torch.cat() followed by a view, which returns the same input as in the first example. This time, the resulting matrix takes about 10 GB of memory.

ff = nn.Linear(100, 1)
x = Variable(torch.randn((60, 200, 100)))
matrix = Variable(torch.zeros(60, 200, 800))

for i in range(matrix.size()[0]):
    for j in range(matrix.size()[1]):
        input = torch.cat((x[:, i, :]), -1).view(60, 100) # input here is same as above
        matrix[:, i, j] = ff(input).squeeze()

In practice, I am using torch.cat() to combine different matrices. My guess is that the explosion in memory usage is caused by torch.cat() growing the number of connections in the computational graph. Is this the case, and how can I avoid it when using torch.cat()?

Yes, the explosion in memory comes from torch.cat() growing the number of connections in the graph.

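What’s happening in this specific case is that you’re calling torch.cat with a Variable argument instead of an iterable of Variables. Parentheses alone don’t make a tuple: (x[:, i, :]) is just x[:, i, :], whereas a one-element tuple needs a trailing comma, (x[:, i, :],). torch.cat then treats the Variable itself as the sequence and indexes into it row by row, and each of those indexing operations adds to the computation graph.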
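To make the mechanism concrete, here is a minimal sketch in current PyTorch, where Variable and Tensor have been merged, so plain tensors stand in for Variables. It assumes a single (60, 100) row with requires_grad=True so that autograd records the operations, and it emulates the old row-by-row iteration with explicit indexing rather than relying on torch.cat accepting a bare tensor. Counting the incoming edges of the cat node shows the difference between the two calls.

```python
import torch

# Hypothetical stand-in for one row x[:, i, :] from the question: shape (60, 100).
# requires_grad=True is an assumption so that autograd records the ops we inspect.
row = torch.randn(60, 100, requires_grad=True)

# Parentheses alone do not create a tuple; a trailing comma does.
assert (row) is row
assert isinstance((row,), tuple)

# Emulate passing a bare Variable to torch.cat: the tensor is treated as a
# sequence and indexed once per entry along dim 0.
rows = [row[k] for k in range(row.size(0))]  # 60 indexing ops, each of shape (100,)
flat = torch.cat(rows, -1)                   # shape (6000,)
rebuilt = flat.view(60, 100)                 # same values as `row`

# Passing an explicit list holding one tensor gives cat a single input.
direct = torch.cat([row])                    # shape (60, 100)

# Each entry of next_functions is one incoming edge of the cat node.
print(len(flat.grad_fn.next_functions))      # 60
print(len(direct.grad_fn.next_functions))    # 1
print(torch.equal(rebuilt, direct))          # True
```

Both calls produce the same values; only the second keeps the graph for this step down to a single edge into cat.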

Here are my profiling results (I’m not sure why my numbers are 10x smaller than yours, but the ratios between the memory consumptions look reasonable):

  • (129 MB) First code block, on my machine.
  • (1.04 GB) Code block with cat as you wrote it. This looks bad, as you mentioned.
  • (411 MB) Replaced torch.cat((x[:, i, :]), -1).view(60, 100) with x[:, i, :].clone(). The ideal memory use should be close to this, because all the inputs stay around in memory.
  • (415 MB) Replaced torch.cat((x[:, i, :]), -1).view(60, 100) with torch.cat([x[:, i, :]]). This is the change I would advise making to avoid the memory jump.

The advice here is that the argument to torch.cat shouldn’t be a Variable; it should be an iterable of Variables.
However, you probably aren’t actually calling torch.cat with a single Variable in your real code, because that doesn’t do anything useful. Could you provide a more detailed example of what you’re doing and how the memory is blowing up?
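For the real use case of combining different matrices, here is a minimal sketch: wrap them in a list (or tuple) so that torch.cat receives one input per matrix. The shapes of a and b are hypothetical, chosen so the result fits the nn.Linear(100, 1) from the question, and plain tensors are used because Variable was merged into Tensor in PyTorch 0.4.

```python
import torch
import torch.nn as nn

# Hypothetical inputs: two different matrices sharing the batch dimension (60).
a = torch.randn(60, 40)
b = torch.randn(60, 60)

# A list gives torch.cat exactly two inputs, joined along the last dimension:
# (60, 40) and (60, 60) -> (60, 100).
combined = torch.cat([a, b], dim=-1)

ff = nn.Linear(100, 1)
out = ff(combined).squeeze(-1)  # shape (60,)
print(combined.shape, out.shape)
```

Squeezing only the last dimension keeps the batch dimension intact even if the batch size happens to be 1.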