How to reduce memory usage when one large network is called many times

Hi there, I am a newbie to PyTorch and I am really impressed by its speed and flexibility. Amazing work! But I have one question about memory usage.

I noticed that PyTorch usually stores all intermediate variables when building the graph. The memory requirement can become very high when a large network is called many times in a for loop. Below is a piece of sample code. There is only one network, AntoEncoderDecoder, and we iteratively feed data through it many times. My understanding is that PyTorch stores all intermediate results of AntoEncoderDecoder for every call and thus requires a lot of memory. Compared with an equivalent Theano implementation, PyTorch requires more GPU memory: in Theano I can use a batch size of 10, but only 2 in PyTorch. Is there anything special I should take care of to reduce memory usage in this case?

Any suggestions would be welcome! Thank you!

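```python
import torch
from torch.autograd import Variable
model = AntoEncoderDecoder()  # the poster's own nn.Module (not shown)
output = Variable(torch.zeros(256, 256))
for i in range(iterations):  # iterations: number of unrolled calls
    output = model(output)  # each call adds another copy of the model's graph
```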
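To see why memory grows with the loop length, here is a self-contained sketch of the loop above written against the current API (Variable was merged into Tensor in PyTorch 0.4, so plain tensors are used). The AntoEncoderDecoder body is a hypothetical stand-in, since the real model is not shown, and the sketch assumes a PyTorch release that provides torch.autograd.graph.saved_tensors_hooks. It counts how many tensors autograd keeps for the backward pass: each call saves the same set again, so the total grows linearly with the number of iterations.

```python
import torch
import torch.nn as nn


class AntoEncoderDecoder(nn.Module):
    """Hypothetical stand-in for the poster's model: encode to 64 dims, decode back."""

    def __init__(self, dim: int = 256, hidden: int = 64):
        super().__init__()
        self.encode = nn.Linear(dim, hidden)
        self.decode = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.decode(torch.tanh(self.encode(x))))


def count_saved_tensors(model: nn.Module, iterations: int):
    """Unroll the model `iterations` times and count tensors saved for backward."""
    saved = []

    def pack(t):
        saved.append(t)  # called once for every tensor autograd stores
        return t

    def unpack(t):
        return t

    # requires_grad=True so every iteration records exactly the same ops
    output = torch.zeros(4, 256, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        for _ in range(iterations):
            output = model(output)
    return len(saved), output


model = AntoEncoderDecoder()
per_call, _ = count_saved_tensors(model, 1)
total, output = count_saved_tensors(model, 10)
print(per_call, total)  # total is 10 * per_call
```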


Just to be sure I understand what you are trying to do:
You have an auto encoder model and a variable output, and you want to apply your auto encoder iterations times to it and then backpropagate all the way through these iterations, to learn an auto encoder that repeatedly encodes and decodes before producing the final output? In that case, your complete network is the auto encoder stacked iterations times.

When you run your model, we just keep in memory what is necessary to compute the backward pass, nothing more.
If you do not need to compute the backward pass, you can use the volatile option as follows: output = Variable(torch.zeros(256,256), volatile=True). This tells PyTorch that it does not need to keep the buffers for the backward pass. If you do so, calling .backward() on the output will obviously raise an error.
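The volatile flag has had no effect since PyTorch 0.4; in current versions the same behaviour comes from the torch.no_grad() context manager. A minimal sketch, using a hypothetical nn.Sequential stand-in for the auto encoder:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real auto encoder.
model = nn.Sequential(nn.Linear(256, 64), nn.Tanh(), nn.Linear(64, 256), nn.Tanh())

output = torch.zeros(4, 256)
with torch.no_grad():  # modern replacement for volatile=True
    for _ in range(10):
        output = model(output)

# No graph was recorded, so nothing was kept for the backward pass.
print(output.requires_grad, output.grad_fn)  # False None

backward_failed = False
try:
    output.sum().backward()
except RuntimeError:
    backward_failed = True  # expected: there is no graph to differentiate
print(backward_failed)  # True
```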

  1. We never store all intermediate variables, but only the ones that are absolutely necessary to compute the backward. Everything else is freed immediately.
  2. What kind of modules/operations are you using in your network? The increased memory usage might come from the fact that we always try to pick fast algorithms in cuDNN, even at the cost of extra memory. We'll try to make it possible to affect these choices in the future.
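Point 1 can be checked directly. The sketch below (assuming a PyTorch release that provides torch.autograd.graph.saved_tensors_hooks) counts the tensors a single operation saves for backward: addition needs nothing from the forward pass to compute its gradient, tanh keeps only its output, and elementwise multiplication keeps both operands.

```python
import torch


def saved_by(fn, *inputs):
    """Return how many tensors autograd saves when evaluating fn(*inputs)."""
    count = 0

    def pack(t):
        nonlocal count
        count += 1
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        fn(*inputs)
    return count


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

print(saved_by(torch.add, a, b))  # 0: the gradient of a + b does not depend on a or b
print(saved_by(torch.tanh, a))    # 1: backward uses 1 - tanh(a)**2, i.e. the output
print(saved_by(torch.mul, a, b))  # 2: each operand's gradient needs the other operand
```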

@albanD Thanks for your suggestion! Yes, your understanding is correct. Actually, I do need to compute the backward pass, so the volatile option is probably not applicable here.
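Since the backward pass is needed, one option available in later PyTorch releases (not something suggested in this thread) is gradient checkpointing with torch.utils.checkpoint.checkpoint: the checkpointed call does not keep its internal activations and recomputes them during backward, trading extra compute for lower memory. A sketch with a hypothetical stand-in model, assuming a release that accepts the use_reentrant argument, checking that the gradients match the plain loop:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical stand-in for the real auto encoder.
model = nn.Sequential(nn.Linear(256, 64), nn.Tanh(), nn.Linear(64, 256), nn.Tanh())


def unrolled(x, iterations, use_checkpoint):
    for _ in range(iterations):
        if use_checkpoint:
            # Activations inside model are not stored; they are recomputed in backward.
            x = checkpoint(model, x, use_reentrant=False)
        else:
            x = model(x)
    return x


x = torch.randn(4, 256, requires_grad=True)
grads = []
for use_checkpoint in (False, True):
    x.grad = None
    unrolled(x, 5, use_checkpoint).sum().backward()
    grads.append(x.grad.clone())

print(torch.allclose(grads[0], grads[1]))  # True
```

The gradients are the same either way; only the peak memory during the forward pass and the amount of recomputation in backward differ.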

@apaszke Thanks for the info! Currently my module works like an RNN, so only simple matrix multiplications, slicing, and nonlinearities are used, but they are repeated many times due to the iteration.

@qianguih I don’t know of any memory issues with these ops; they shouldn’t differ that much from what Theano needs. Are you using 0.1.10?

Yes, it is 0.1.10. I guess Theano optimizes the computation graph heavily during function compilation.

Maybe. It’s hard to say what’s happening without having the model details.