Memory usage increases each step until it eventually runs out of memory

I’m doing some sequence decoding with a GRU network I’ve trained, using a variant of beam search that stores the most likely sequences in a priority queue. My network works fine when I train and test it the usual way (e.g. batching with fixed sequence lengths). The trouble is that memory usage is huge in this decoding function, and I’m not sure why. It grows very quickly until the process dies with the message “(interrupted by signal 9: SIGKILL)”.

My decoding method runs one step at a time with a batch size of 1 and a maximum input length of 200 steps fed into the network (although I store the entire sequence in the priority queue). It should run for about 100,000 time steps, but it dies at around 10,000.
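For readers unfamiliar with this setup, here is a minimal sketch of a priority-queue beam step of the kind described above. It is not the poster’s code: `beam_step`, `toy_step` and `MAX_INPUT_LEN` are hypothetical names, `toy_step` stands in for the GRU with fixed probabilities, and it assumes the network sees the most recent 200 tokens. The point to notice is that scores are stored as plain Python floats; if they were tensors that require gradients, every queue entry would keep an autograd graph alive.

```python
import heapq
import itertools
import math

MAX_INPUT_LEN = 200  # assumption: only the most recent 200 tokens are fed to the network


def beam_step(beam, step_fn, beam_width):
    """Extend every hypothesis by one token and keep the `beam_width` best.

    Entries are (neg_log_prob, tiebreak, sequence). Scores are plain Python
    floats, so the queue never holds on to an autograd graph. The tiebreak
    counter stops heapq from ever comparing the sequences themselves.
    """
    counter = itertools.count()
    candidates = []
    for neg_lp, _, seq in beam:
        window = seq[-MAX_INPUT_LEN:]  # truncated model input; full seq is kept
        for token, log_p in step_fn(window):
            candidates.append((neg_lp - log_p, next(counter), seq + [token]))
    # Smallest negative log-probability = most likely sequence.
    return heapq.nsmallest(beam_width, candidates)


def toy_step(window):
    # Hypothetical stand-in for the GRU: fixed transition log-probabilities.
    return [(0, math.log(0.6)), (1, math.log(0.4))]


beam = [(0.0, 0, [])]
for _ in range(3):
    beam = beam_step(beam, toy_step, beam_width=2)
best_sequence = beam[0][2]
```

With these toy probabilities the best hypothesis after three steps is `[0, 0, 0]`.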

I’ve used tracemalloc to try to see where all this memory is being allocated over a number of steps:

[ Top 10 ]
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=52.5 KiB, count=13232, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=52.4 KiB, count=13231, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=51.9 KiB, count=13231, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=51.9 KiB, count=13230, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=51.7 KiB, count=13226, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/ size=51.7 KiB, count=13225, average=4 B
/Applications/PyCharm size=14.0 KiB, count=2274, average=6 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/ size=13.2 KiB, count=26, average=518 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/ size=13.1 KiB, count=1373, average=10 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/ size=9056 B, count=2264, average=4 B
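The listing above has the shape of plain `Snapshot.statistics("lineno")` output for a single snapshot. A minimal sketch of a variant that may narrow things down faster: take two snapshots some steps apart and use `Snapshot.compare_to`, which ranks source lines by how much their allocations changed in between. `run_steps` and `sink` are hypothetical stand-ins for the decoding loop and whatever is retaining memory.

```python
import tracemalloc

tracemalloc.start()


def run_steps(n, sink):
    # Hypothetical stand-in for the decoding loop: deliberately keeps every result alive.
    for i in range(n):
        sink.append([i] * 10)


sink = []
run_steps(100, sink)
before = tracemalloc.take_snapshot()
run_steps(1000, sink)
after = tracemalloc.take_snapshot()

# Source lines sorted by how much their allocations changed between the snapshots.
top_stats = after.compare_to(before, "lineno")
print("[ Top 10 differences ]")
for stat in top_stats[:10]:
    print(stat)
```

A line whose size keeps growing across successive comparisons is a much stronger lead than one that is merely large in a single snapshot.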

The traceback of the top one is:

320400 memory blocks: 1252.4 KiB
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/", line 115
    hidden = inner(input[i], hidden, *weight)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/", line 86
    hy, output = inner(input, hidden[l], weight[l], batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/", line 243
    nexth, output = func(input, hidden, weight, batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/", line 323
    return func(input, *fargs, **fkwargs)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/", line 192
    output, hidden = func(input, self.all_weights, hx, batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/", line 491
    result = self.forward(*input, **kwargs)
  File "/Users/askates/Documents/Projects/model/", line 112
    out, h = self.gru(z, h)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/", line 491
    result = self.forward(*input, **kwargs)
  File "/Users/askates/Documents/Projects/model/", line 144
    _, _, _, log_prob_transition, h = self.decoder_network(input_feed, state_feed)
  File "/Users/askates/Documents/Projects/model/", line 318
    sequence_list = beam_decoder.dynamic_beam_search()
  File "/Users/askates/Documents/Projects/model/", line 396
    decode_state_sequence(decoder, data, params)
  File "/Users/askates/Documents/Projects/model/", line 436
  File "/Applications/PyCharm", line 18
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Applications/PyCharm", line 1068
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm", line 1658
    globals =['file'], None, None, is_module)
  File "/Applications/PyCharm", line 1664
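The traceback above resembles the output of the “traceback” example in the Python `tracemalloc` documentation. As a minimal sketch of how such a trace is produced (the `leaky_step` function is hypothetical), tracing has to be started with enough frames, and the statistics grouped by `"traceback"` rather than by `"lineno"`:

```python
import tracemalloc

tracemalloc.start(25)  # store up to 25 frames per allocation (default is 1)


def leaky_step(store):
    # Hypothetical step that keeps 1 KiB alive per call.
    store.append(bytearray(1024))


store = []
for _ in range(500):
    leaky_step(store)

snapshot = tracemalloc.take_snapshot()
top = snapshot.statistics("traceback")[0]  # biggest group of identical tracebacks
print(f"{top.count} memory blocks: {top.size / 1024:.1f} KiB")
for line in top.traceback.format():
    print(line)
```

With a single frame the report stops at the innermost library line (here, inside `torch/nn/_functions/`); more frames reveal which caller in the decoding code is responsible.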

I can post my code if need be, but it’s a complex model, so it would not be trivial to post all the relevant code. Any ideas on how I can investigate this further, or even a guess at what the cause might be?
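One way to investigate without posting the whole model is to count how many tensors are alive every few hundred decoding steps; if the count climbs steadily, something (for example the priority queue) is keeping references to them. A minimal sketch using Python’s `gc` module; `count_live_tensors` is a hypothetical helper name, and it assumes tensors are tracked by the garbage collector, which is the case for `torch.Tensor` objects in current PyTorch releases.

```python
import gc

import torch


def count_live_tensors():
    """Count tensor objects currently tracked by Python's garbage collector."""
    return sum(1 for obj in gc.get_objects() if torch.is_tensor(obj))


baseline = count_live_tensors()
kept = [torch.zeros(3) for _ in range(10)]  # simulate a leak: 10 tensors stay referenced
growth = count_live_tensors() - baseline
print(growth)
```

In a real decoding loop you would print this count (and perhaps the tensors’ shapes) periodically rather than once.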

Thanks for your help.

Edit: added a different traceback.

This is just a guess, but I ran into this problem when a tensor kept accumulating its computation history because of requires_grad=True, even though I never intended to differentiate through it afterwards.
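A minimal sketch of how this happens with a GRU, using a toy model with arbitrary sizes: because the module’s parameters require gradients, each new hidden state carries a `grad_fn` that references the previous step’s graph, so a hidden state fed back in at every step keeps the whole decoding history reachable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=8)  # parameters have requires_grad=True

h = torch.zeros(1, 1, 8)  # (num_layers, batch, hidden_size)
x = torch.randn(1, 1, 4)  # (seq_len, batch, input_size)
for step in range(100):
    out, h = gru(x, h)  # each new h points back into the previous step's graph

print(h.requires_grad, h.grad_fn is not None)  # True True
```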

I had an inkling that this was the issue, but manually setting requires_grad to False didn’t seem to work. I assume I must have missed a tensor somewhere, as wrapping everything in with torch.no_grad(): seemed to work!
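A minimal sketch of that fix on the same kind of toy GRU (sizes are arbitrary): inside `torch.no_grad()` no graph is recorded, so nothing links one step’s hidden state to the previous one. Note that `model.eval()` only changes layer behaviour such as dropout; it does not switch off autograd by itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=8)
gru.eval()  # disables dropout etc.; does NOT stop gradient tracking

h = torch.zeros(1, 1, 8)  # (num_layers, batch, hidden_size)
x = torch.randn(1, 1, 4)  # (seq_len, batch, input_size)
with torch.no_grad():
    for step in range(100):
        out, h = gru(x, h)  # no autograd graph is built inside this block

print(h.requires_grad, h.grad_fn)  # False None
```

Wrapping the whole decoding call this way avoids having to find every tensor whose requires_grad flag would otherwise need changing.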


The usual suspects here are accumulated training statistics, such as a loss that is summed without a .detach() or .item(), and, for RNN models, hidden states that are not properly reset between runs (so they still hold a link to the past).
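A minimal sketch of both fixes in a toy GRU training loop that carries the hidden state across chunks; the model sizes, random data, and SGD optimizer are arbitrary placeholders. `h.detach()` keeps the hidden state’s values but drops its history. Without it, the second `backward()` would try to go through the previous chunk’s already-freed graph and raise an error. `loss.item()` turns the running total into a plain float, so no graph is kept alive by it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)
opt = torch.optim.SGD(list(gru.parameters()) + list(head.parameters()), lr=0.01)

h = torch.zeros(1, 2, 8)  # (num_layers, batch, hidden_size), carried across chunks
running_loss = 0.0
for chunk in range(5):
    x = torch.randn(10, 2, 4)  # (seq_len, batch, input_size), toy data
    y = torch.randn(10, 2, 1)
    h = h.detach()  # keep the values, cut the link to the previous chunk's graph
    out, h = gru(x, h)
    loss = nn.functional.mse_loss(head(out), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    running_loss += loss.item()  # a Python float, not a tensor with a graph
```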