I have a use case where I run a forward pass for each sample in a batch, but only accumulate the loss for some of the samples, based on a condition on the model's output for that sample. Here is some illustrative code:
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    total_loss = 0
    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())
        out = model(im)
        # if out satisfies some condition, we calculate the loss
        # for this sample; otherwise we proceed to the next sample
        if some_condition(out):
            loss = criterion(out, y)
        else:
            continue
        total_loss += loss
        loss_count_local += 1
        if loss_count_local == 32 or i == (len(target) - 1):
            total_loss /= loss_count_local
            total_loss.backward()
            total_loss = 0
            loss_count_local = 0
    optimizer.step()
My question is: since I run the forward pass for every sample but only call backward for some of them, when will the graphs of the samples that do not contribute to the loss be freed? Will those graphs be freed only after the for loop has ended, or immediately after I do the forward pass for the next sample? I am a little confused here.
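To make the question concrete, here is a minimal sketch (not my actual code) of how I have been trying to probe this, assuming torch.cuda.memory_allocated() reflects when the graph's saved buffers are released:

import torch

# Hypothetical probe: a small two-layer model whose graph has to keep
# the 512x4096 hidden activation alive for a potential backward pass.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).cuda()
x = torch.randn(512, 1024).cuda()

base = torch.cuda.memory_allocated()
out = model(x)          # forward builds a graph and saves the hidden activation
with_graph = torch.cuda.memory_allocated()
del out                 # drop the last reference to the graph, which is what
                        # reassigning `out` in the next loop iteration would do
freed = torch.cuda.memory_allocated()
print(base, with_graph, freed)   # I expect `freed` to be close to `base` again

If it is reference counting that frees the graph, I would expect the memory to drop back as soon as `out` is reassigned in the next iteration, rather than at the end of the loop.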
Also, for the samples that do contribute to total_loss, their graphs should be freed immediately after we call total_loss.backward(). Is that right?
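The reason I believe this is the usual error you get when backward is called on the same graph twice; a minimal sketch of what I mean (with a toy model, not my real one):

import torch

model = torch.nn.Linear(10, 1)
x = torch.randn(4, 10)
loss = model(x).sum()

loss.backward()        # by default this frees the graph's intermediate buffers
try:
    loss.backward()    # a second call should fail because the buffers are gone
except RuntimeError as e:
    print(e)           # "Trying to backward through the graph a second time ..."

So my understanding is that backward() frees the buffers of exactly the graph it runs over, and everything else waits for reference counting. Please correct me if that is wrong.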