I am trying to sample from a network during training in order to compute the loss. However, I am getting `RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58` when the number of sampling iterations is too large, which I don't understand.
Pseudo-code (sorry, the actual code is very long):
```python
# sampling_network is pre-trained; no updates during this training
for param in sampling_network.parameters():
    param.requires_grad = False

optimizer.zero_grad()
sampling_times = N
total_loss = 0.

sampling_inputs = main_network(inputs)

# do sampling
for sampling_time in range(sampling_times):
    prediction = sampling_network(sampling_inputs)
    loss = compute_loss(prediction)
    total_loss += loss

total_loss.backward()
optimizer.step()
```
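To make the pattern concrete, here is a minimal self-contained version of the pseudo-code above; the layer shapes, the linear networks, and the squared-mean loss are stand-ins I picked for illustration, not my actual code:

```python
import torch
import torch.nn as nn

# Stand-in networks (shapes are arbitrary, for illustration only)
main_network = nn.Linear(8, 8)
sampling_network = nn.Linear(8, 1)

# sampling_network is pre-trained; freeze its parameters
for param in sampling_network.parameters():
    param.requires_grad = False

optimizer = torch.optim.SGD(main_network.parameters(), lr=0.01)

def compute_loss(prediction):
    # stand-in loss function
    return prediction.pow(2).mean()

inputs = torch.randn(4, 8)
N = 10  # sampling_times

optimizer.zero_grad()
total_loss = 0.
sampling_inputs = main_network(inputs)

# do sampling: each pass through the frozen network still builds
# a graph back to main_network via sampling_inputs
for _ in range(N):
    prediction = sampling_network(sampling_inputs)
    total_loss = total_loss + compute_loss(prediction)

total_loss.backward()
optimizer.step()
```

On CPU with these tiny networks this runs fine for any N; the out-of-memory error only shows up on the GPU with my real networks.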
In my case, training runs fine when N <= 5, but throws the "out of memory" error when N > 5.
What I don't understand is why memory usage should depend on the setting of N at all, since the same sampling_network is simply called multiple times. Am I missing something?