I am trying to sample from a network during training in order to compute the loss. However, I am getting `RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58` when the number of sampling iterations is too large, which I don't understand.
Pseudo-code (sorry, the actual code is very long):
```python
# sampling_network is pre-trained; no updates during this training
for param in sampling_network.parameters():
    param.requires_grad = False

optimizer.zero_grad()
sampling_times = N
total_loss = 0.

sampling_inputs = main_network(inputs)

# do sampling
for sampling_time in range(sampling_times):
    prediction = sampling_network(sampling_inputs)
    loss = compute_loss(prediction)
    total_loss += loss

total_loss.backward()
optimizer.step()
```
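To make the pattern concrete, here is a minimal self-contained version of the pseudo-code above; the layer shapes, the linear networks, and the squared-mean loss are stand-ins I picked for illustration, not my actual code:

```python
import torch
import torch.nn as nn

# Stand-in networks (shapes are arbitrary, for illustration only)
main_network = nn.Linear(8, 8)
sampling_network = nn.Linear(8, 1)

# sampling_network is pre-trained; freeze its parameters
for param in sampling_network.parameters():
    param.requires_grad = False

optimizer = torch.optim.SGD(main_network.parameters(), lr=0.01)

def compute_loss(prediction):
    # stand-in loss function
    return prediction.pow(2).mean()

inputs = torch.randn(4, 8)
N = 10  # sampling_times

optimizer.zero_grad()
total_loss = 0.
sampling_inputs = main_network(inputs)

# do sampling: each pass through the frozen network still builds
# a graph back to main_network via sampling_inputs
for _ in range(N):
    prediction = sampling_network(sampling_inputs)
    total_loss = total_loss + compute_loss(prediction)

total_loss.backward()
optimizer.step()
```

On CPU with these tiny networks this runs fine for any N; the out-of-memory error only shows up on the GPU with my real networks.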
In my case, training runs fine when N <= 5, but throws the "out of memory" error when N > 5.
What I don't understand is why memory usage should depend on the setting of N at all, since the same sampling_network is simply called multiple times. Am I missing something?