Hi there, I am new to PyTorch and really impressed by its speed and flexibility. Amazing work! But I have a question about memory usage.
I noticed that PyTorch stores all intermediate variables when building the graph. The memory requirement can be very high when a large network is called many times in a for loop. Below is a sample snippet. There is only one network, AntoEncoderDecoder, and we feed data through it iteratively many times. My understanding is that PyTorch will store all intermediate results inside AntoEncoderDecoder and thus requires a lot of memory. Compared with an equivalent Theano implementation, PyTorch needs more GPU memory: in Theano I can use a batch size of 10, but only 2 in PyTorch. Is there anything special I should take care of to reduce memory usage in this case?
Any suggestions would be welcome! Thank you!
```python
model = AntoEncoderDecoder()
output = Variable(torch.zeros(256, 256))
for i in range(iterations):
    output = model(output)
```
Just to be sure I understand what you are trying to do: you have an auto encoder model and a variable output, and you want to apply your auto encoder iterations times to it, then backpropagate all the way through these iterations to learn an auto encoder that repeatedly encodes and decodes multiple times before producing the final output? Here, your complete network is the auto encoder stacked iterations times.
When you run your model, we only keep in memory what is necessary to compute the backward pass, nothing more.
If you do not need to compute the backward pass, you can use the volatile option as follows:
output = Variable(torch.zeros(256, 256), volatile=True)
This tells PyTorch that it does not need to keep the buffers for the backward pass. If you do so, calling .backward() on the output will obviously raise an error.
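(Note for readers on current PyTorch: volatile dates from the 0.1.x API and has since been removed; torch.no_grad() is the modern equivalent. A minimal sketch of the same idea, with purely illustrative tensor shapes:)

```python
import torch

x = torch.zeros(256, 256)

# Inside no_grad(), autograd records no graph, so no intermediate
# buffers are kept for a backward pass -- the modern equivalent of
# the old volatile=True flag.
with torch.no_grad():
    y = torch.tanh(x @ x)

# y is detached from any graph; calling y.backward() would fail.
assert y.requires_grad is False
assert y.grad_fn is None
```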
@albanD Thanks for your suggestion! Yes, your understanding is correct. Actually, I do need to compute the backward pass, so the volatile option may not be applicable here.
@apaszke Thanks for the info! Currently my module functions like an RNN: only simple matrix multiplications, slicing, and nonlinearity functions are used, but they are repeated many times due to the iteration.
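For context, a minimal sketch of the kind of module described above. The names (IterativeStep, iterations, the shapes) are illustrative, not from the original post; the point is that backward must traverse every iteration, so the activations of each iteration stay in memory until .backward() finishes:

```python
import torch
import torch.nn as nn

# Hypothetical RNN-like step: one matrix multiplication, a (trivial)
# slice, and a nonlinearity, applied repeatedly in a loop.
class IterativeStep(nn.Module):
    def __init__(self, size=256):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(size, size) * 0.01)

    def forward(self, x):
        h = x @ self.weight        # matrix multiplication
        h = h[:, : h.size(1)]      # slicing
        return torch.tanh(h)       # nonlinearity

model = IterativeStep()
output = torch.zeros(4, 256, requires_grad=True)
for _ in range(10):
    output = model(output)

# Backward walks through all 10 iterations, which is why the
# intermediate results of every iteration must be kept around.
output.sum().backward()
assert model.weight.grad is not None
```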
@qianguih I don’t know of any memory issues with these ops; they shouldn’t differ that much from what Theano needs. Are you using 0.1.10?
Yes, it is 0.1.10. I guess Theano optimizes the computation graph a lot during function compilation.
Maybe. It’s hard to say what’s happening without having the model details.