Sampling using a network vs memory usage


I am trying to sample from a network (during training) in order to compute a loss function. However, I get RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/ if the number of sampling iterations is too large, which I don’t understand.

The Pseudo-code (sorry, the actual code is very large):

# sampling_network is pre-trained and is not updated during this training
for param in sampling_network.parameters():
    param.requires_grad = False


sampling_times = N
total_loss = 0.

sampling_inputs = main_network(inputs)
# do sampling
for sampling_time in range(sampling_times):
    prediction = sampling_network(sampling_inputs)
    loss = compute_loss(prediction)
    total_loss += loss

#fixed typo

In my case, training is fine when N <= 5, but throws an “out of memory” error when N > 5.

What I don’t understand is why memory usage depends on N at all, since the same sampling_network is simply called multiple times. Am I missing something?


What are you doing with total_loss?
Currently you are storing the computation graph in it.
If you just need it for printing, you should use:

total_loss += loss.item()

Or do you need it somewhere for a backward pass?
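To make the difference concrete, here is a minimal sketch comparing the two ways of accumulating. The network, input, and loss are hypothetical stand-ins, not the original poster's code. Adding the loss tensor keeps every iteration's graph reachable through grad_fn, while .item() produces a plain Python float with no graph attached.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 1)      # hypothetical stand-in for the real network
x = torch.randn(8, 4)      # hypothetical input batch

graph_total = 0.0          # accumulates tensors: every graph stays alive
float_total = 0.0          # accumulates Python floats: no graph is kept

for _ in range(3):
    loss = net(x).pow(2).mean()        # placeholder loss
    graph_total = graph_total + loss   # still attached to autograd
    float_total += loss.item()         # detached scalar value

print(type(graph_total).__name__, graph_total.grad_fn is not None)
print(type(float_total).__name__)
```

Use the float version only for logging; it cannot be used for a backward pass.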

Sorry, typo, it should be


Also fixed in main thread.

OK, that makes sense.
The memory usage won’t stay the same, since for each pass a new computation graph is created and stored.
You could call .backward() in the for loop and optimizer.step() outside of it.
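A rough sketch of that suggestion follows. The layer sizes, the Dropout layer used to make each pass stochastic, the placeholder loss, and the SGD optimizer are all assumptions for illustration. Because sampling_inputs is computed once outside the loop, every backward call passes through the same main_network graph, so this sketch uses retain_graph=True to keep that shared part alive between iterations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for the networks in the question.
main_network = nn.Linear(4, 4)
sampling_network = nn.Sequential(
    nn.Linear(4, 4), nn.Dropout(0.5), nn.Linear(4, 1)
)
for param in sampling_network.parameters():
    param.requires_grad = False        # frozen, as in the question

optimizer = torch.optim.SGD(main_network.parameters(), lr=0.1)
inputs = torch.randn(8, 4)
sampling_times = 10

optimizer.zero_grad()
sampling_inputs = main_network(inputs)
total_loss = 0.0                       # float, used for logging only
for _ in range(sampling_times):
    prediction = sampling_network(sampling_inputs)
    loss = prediction.pow(2).mean()    # placeholder for compute_loss
    # Gradients accumulate in main_network's .grad; retain_graph keeps
    # the shared main_network graph usable for the next iteration.
    loss.backward(retain_graph=True)
    total_loss += loss.item()
optimizer.step()                       # single update after all samples
```

Each iteration's sampling graph is no longer referenced once loss is rebound, so only one such graph should be held at a time instead of N.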

@ptrblck, Thanks. If I understand correctly, calling .backward() within the loop and step() outside of it means the gradients are computed (and accumulated) at every sampling iteration, and the trainable variables are updated once at the end of the sampling process. This has exactly the same effect on the network (in terms of learning), but is more memory efficient. Am I right?

Yes, you will save some memory but need more compute, since the gradients will be calculated in every step.
Besides that it should be identical.
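The equivalence can be checked on a toy setup. This is a sketch with made-up layer sizes and a placeholder loss; the helper name grads is hypothetical. It reseeds the RNG so both variants draw the same dropout masks, then compares the gradients on the trainable network.

```python
import torch
import torch.nn as nn

def grads(backward_in_loop, n=4):
    """Return main-network gradients for one of the two strategies."""
    torch.manual_seed(0)               # same weights, data, dropout masks
    main = nn.Linear(3, 3)
    sampler = nn.Sequential(nn.Dropout(0.5), nn.Linear(3, 1))
    for p in sampler.parameters():
        p.requires_grad = False        # frozen sampling network
    h = main(torch.randn(5, 3))
    total = 0.0
    for _ in range(n):
        loss = sampler(h).pow(2).mean()    # placeholder loss
        if backward_in_loop:
            loss.backward(retain_graph=True)   # accumulate into .grad
        else:
            total = total + loss               # keep all graphs alive
    if not backward_in_loop:
        total.backward()                       # one backward over the sum
    return [p.grad.clone() for p in main.parameters()]

g_loop = grads(True)
g_sum = grads(False)
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_loop, g_sum))
print(same)
```

The gradient of a sum is the sum of the gradients, so the two should agree up to floating-point rounding.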

Cool, thanks. I understand it now.