I’m doing sequence decoding with a GRU network I’ve trained, using a variant of beam search that stores the most likely sequences in a priority queue. The network works fine when I train and test it in the usual way (e.g. batching with fixed sequence lengths). The trouble is that memory usage in this decoding function is huge, and I’m not sure why: it grows very quickly until I get the message “(interrupted by signal 9: SIGKILL)” and the process dies.
My decoding method runs one step at a time with a batch size of 1, feeding the network a maximum sequence length of 200 steps (although I’m storing the entire sequence in the priority queue). It should run for about 100,000 time steps, but it dies at around 10,000.
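To make the setup concrete, here is a minimal sketch of this kind of decoding loop. It is not the actual model code: `step_fn` is a hypothetical stand-in for the GRU, and `beam_decode`, `beam_width` and the two-symbol vocabulary are assumptions for illustration only. The point is that the full sequence is kept in the queue while only the last 200 steps are fed to the network.

```python
import heapq
import math

MAX_INPUT_LEN = 200  # only the last 200 steps are fed to the network


def step_fn(window):
    """Hypothetical stand-in for the GRU: log-probs over two symbols."""
    p = 0.7 if len(window) % 2 == 0 else 0.4
    return [math.log(p), math.log(1.0 - p)]


def beam_decode(num_steps, beam_width=3):
    # Each entry is (negative log-prob, full sequence). A smaller
    # negative log-prob means a more likely hypothesis.
    beam = [(0.0, ())]
    for _ in range(num_steps):
        candidates = []
        for neg_logp, seq in beam:
            window = seq[-MAX_INPUT_LEN:]  # truncate the network input only
            for symbol, logp in enumerate(step_fn(window)):
                candidates.append((neg_logp - logp, seq + (symbol,)))
        # Keep only the beam_width most likely hypotheses, best first.
        beam = heapq.nsmallest(beam_width, candidates)
    return beam


result = beam_decode(250)
```

In a sketch like this the stored sequences grow linearly with the number of steps, since each hypothesis keeps its whole history even though the network input is capped.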
I’ve used tracemalloc to try to see where all this memory is being allocated over a number of steps:
[ Top 10 ]
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:115: size=52.5 KiB, count=13232, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:56: size=52.4 KiB, count=13231, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:63: size=51.9 KiB, count=13231, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:61: size=51.9 KiB, count=13230, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:62: size=51.7 KiB, count=13226, average=4 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:57: size=51.7 KiB, count=13225, average=4 B
/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py:514: size=14.0 KiB, count=2274, average=6 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/threading.py:884: size=13.2 KiB, count=26, average=518 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/threading.py:1230: size=13.1 KiB, count=1373, average=10 B
/Users/askates/anaconda3/envs/torch/lib/python3.6/queue.py:160: size=9056 B, count=2264, average=4 B
The traceback of the top one is:
320400 memory blocks: 1252.4 KiB
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 115
    hidden = inner(input[i], hidden, *weight)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 86
    hy, output = inner(input, hidden[l], weight[l], batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 243
    nexth, output = func(input, hidden, weight, batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 323
    return func(input, *fargs, **fkwargs)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 192
    output, hidden = func(input, self.all_weights, hx, batch_sizes)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491
    result = self.forward(*input, **kwargs)
  File "/Users/askates/Documents/Projects/model/modules.py", line 112
    out, h = self.gru(z, h)
  File "/Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491
    result = self.forward(*input, **kwargs)
  File "/Users/askates/Documents/Projects/model/decoder.py", line 144
    _, _, _, log_prob_transition, h = self.decoder_network(input_feed, state_feed)
  File "/Users/askates/Documents/Projects/model/gumbel.py", line 318
    sequence_list = beam_decoder.dynamic_beam_search()
  File "/Users/askates/Documents/Projects/model/gumbel.py", line 396
    decode_state_sequence(decoder, data, params)
  File "/Users/askates/Documents/Projects/model/gumbel.py", line 436
    main(args)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1068
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1658
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1664
    main()
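For reference, a minimal sketch of how statistics and tracebacks in this format can be collected with tracemalloc. The `run_steps` function and the `store` list are hypothetical stand-ins for the real decoding loop and the priority queue; only the tracemalloc calls are the point here.

```python
import tracemalloc


def run_steps(store, n):
    """Hypothetical workload that keeps references alive between steps."""
    for i in range(n):
        store.append([float(i)] * 10)


tracemalloc.start(25)  # keep up to 25 frames per allocation for tracebacks
store = []
run_steps(store, 1000)
before = tracemalloc.take_snapshot()
run_steps(store, 1000)
after = tracemalloc.take_snapshot()

# Per-line totals in the later snapshot, largest first.
top_stats = after.statistics("lineno")
print("[ Top 10 ]")
for stat in top_stats[:10]:
    print(stat)

# Per-line growth between the two snapshots.
diff_stats = after.compare_to(before, "lineno")
for stat in diff_stats[:3]:
    print(stat)

# Full traceback of the largest allocation site.
largest = after.statistics("traceback")[0]
print("%d memory blocks: %.1f KiB" % (largest.count, largest.size / 1024))
for line in largest.traceback.format():
    print(line)

tracemalloc.stop()
```

Comparing two snapshots taken some steps apart (`compare_to`) tends to be more informative than a single snapshot, because it shows which lines are growing rather than which are merely large. Note that tracemalloc only tracks memory allocated through Python's allocators, so memory allocated natively by an extension library may not appear in these numbers.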
I can post my code if need be, but it’s a complex model, so posting all the relevant code would not be trivial. Any ideas on how I can investigate the issue further, or guesses as to what the cause might be?
Thanks for your help.
Edit: added a different traceback.