High memory usage while training

Hello all,
I'm training a simple RNN to predict a label at each input timestep on a huge random dataset.
I recorded memory usage while training and noticed that it increases linearly with the amount of data processed:

(VSIZE = virtual memory size as reported by Ubuntu; %MEM = percentage of system RAM used; x-axis = time in seconds)
My training script for reference:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.rnn = nn.RNN(input_size=200, hidden_size=1000, num_layers=1)
        self.linear = nn.Linear(1000, 100)

    def forward(self, x, init):
        x = self.rnn(x, init)[0]
        # Flatten (seq_len, batch, hidden) so the linear layer is applied per timestep
        y = self.linear(x.view(x.size(0) * x.size(1), x.size(2)))
        return y.view(x.size(0), x.size(1), y.size(1))

net = testNet()
init = Variable(torch.zeros(1, 4, 1000))  # initial hidden state
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

total_loss = 0.0
for i in range(10000):  # 10000 mini-batches
    input = Variable(torch.randn(1000, 4, 200))  # seq_len = 1000, batch_size = 4, features = 200
    target = Variable(torch.LongTensor(4, 1000).zero_())

    optimizer.zero_grad()
    output = net(input, init)
    loss = criterion(output.view(-1, output.size(2)), target.view(-1))
    loss.backward()
    optimizer.step()
    total_loss += loss[0]

print(total_loss)

I expect memory usage to stay constant from one mini-batch to the next. What might be the problem? (Correct me if my script is wrong.)


hi @NgPDat

I’m trying to reproduce your results.

Can you tell me the units of VSIZE? Is it bytes?
And %MEM, is it a percentage of the system memory?

So far my run is pretty stable at around 105 MB after 400 mini-batches; I'll keep it running for a while.

I think I see the problem. You have to remember that loss is a Variable, and indexing a Variable always returns a Variable, even if it's one-dimensional! So when you do total_loss += loss[0], you're actually making total_loss a Variable, and every iteration adds another subgraph to its history; those graphs can never be freed, because total_loss still holds a reference to them. Just replace total_loss += loss[0] with total_loss += loss.data[0] and it should be back to normal.
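To make the distinction concrete, here is a minimal sketch (assuming the same pre-0.4 Variable API as the script above) of what each expression returns:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3), requires_grad=True)
loss = (x * 2).sum()

print(type(loss[0]))       # a Variable: it keeps a reference to the whole graph
print(type(loss.data[0]))  # a plain Python float: no graph attached

Accumulating the float instead of the Variable lets PyTorch free each mini-batch's graph as soon as backward() finishes.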


Works like a charm! Thank you.
I think I have a better understanding of Variable now.

VSIZE is in kilobytes.
Yes, %MEM is a percentage of total system memory.

The script I used for recording memory usage is from here: http://stackoverflow.com/questions/7998302/graphing-a-processs-memory-usage
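For reference, a minimal sketch of the same idea in Python using psutil (the psutil dependency and the record_memory helper are my own choices here; the linked script polls ps output instead):

import time
import psutil

def record_memory(pid, interval=1.0, log_path="memory.log"):
    # Poll a process once per interval and log VSIZE (in KB) and %MEM.
    proc = psutil.Process(pid)
    with open(log_path, "w") as log:
        while proc.is_running():
            vsize_kb = proc.memory_info().vms // 1024  # virtual memory size, KB
            mem_pct = proc.memory_percent()            # percent of system RAM (RSS-based)
            log.write("%.1f %d %.2f\n" % (time.time(), vsize_kb, mem_pct))
            log.flush()
            time.sleep(interval)

The logged columns can then be plotted with gnuplot or matplotlib to produce a graph like the one above.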