Hello!

When passing the result of **torch.cat()** into a Linear layer, the memory required to store the output is much greater than expected. For illustration purposes, I have a simplified example.

Here I have a matrix whose values I want to fill in iteratively by feeding slices of **x** into a linear layer. I do not use torch.cat() here. The resulting matrix takes about 1 GB of memory.

```
import torch
import torch.nn as nn
from torch.autograd import Variable

ff = nn.Linear(100, 1)
x = Variable(torch.randn(60, 200, 100))
matrix = Variable(torch.zeros(60, 200, 800))
for i in range(matrix.size(0)):
    for j in range(matrix.size(1)):
        input = x[:, i, :]
        matrix[:, i, j] = ff(input).squeeze()
```

In this next code block, I am doing the same task as above. The difference is that, before feeding into the linear layer, I first pass the slice of **x** through **torch.cat()** followed by a view, which returns the same input variable as in the first example. This time, the resulting matrix takes about 10 GB of memory.

```
import torch
import torch.nn as nn
from torch.autograd import Variable

ff = nn.Linear(100, 1)
x = Variable(torch.randn(60, 200, 100))
matrix = Variable(torch.zeros(60, 200, 800))
for i in range(matrix.size(0)):
    for j in range(matrix.size(1)):
        input = torch.cat((x[:, i, :]), -1).view(60, 100)  # input here is same as above
        matrix[:, i, j] = ff(input).squeeze()
```

In practice, what I am doing is using torch.cat() to combine different matrices. I'm guessing the explosion in memory usage comes from torch.cat() growing the number of connections in the computational graph. Is this the case, and how can I avoid it when using torch.cat()?
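For comparison, one thing I have tried (a sketch only, and it assumes the intent of the simplified example is `matrix[b, i, j] = ff(x[b, i, :])`, which my real code may not reduce to): since nn.Linear operates on the last dimension, the whole loop can be replaced by a single call, so autograd records one matmul instead of thousands of cat/view/linear nodes.

```python
import torch
import torch.nn as nn

ff = nn.Linear(100, 1)
x = torch.randn(60, 200, 100)

# nn.Linear maps the last dimension, so one call covers every slice of x
# at once: out[b, i] == ff(x[b, i, :]).
out = ff(x).squeeze(-1)  # shape: (60, 200)

# In the loop above, every column j of a given (b, i) cell receives the
# same value, so broadcasting with expand() fills the (60, 200, 800)
# matrix without copying data or adding graph nodes per cell.
matrix = out.unsqueeze(-1).expand(60, 200, 800)
```

This avoids the question of cat's graph growth entirely in the toy case, though it obviously does not apply if the real code genuinely needs to concatenate different matrices per step.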