When will the computation graph be freed if I only do forward for some samples?

I have a use case where I run the forward pass for each sample in a batch and only accumulate the loss for some of the samples, based on a condition on the model's output for that sample. Here is an illustrative code snippet:

for batch_idx, (data, target) in enumerate(train_loader):
    total_loss = 0

    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())

        out = model(im)

        # if out satisfies some condition, we accumulate the loss
        # for this sample, else proceed to the next sample
        if some_condition(out):
            loss = criterion(out, y)
            total_loss += loss
            loss_count_local += 1

        if loss_count_local == 32 or i == (len(target) - 1):
            if loss_count_local > 0:
                total_loss /= loss_count_local
                total_loss.backward()
            total_loss = 0
            loss_count_local = 0


My question is: since I run forward for all samples but only backward for some of them, when will the graphs for the samples that do not contribute to the loss be freed? Will they be freed only after the for loop has ended, or immediately after I run forward for the next sample? I am a little confused here.

Also, for the samples that do contribute to total_loss, their graphs will be freed immediately after we call total_loss.backward(). Is that right?

A (sub-)graph will be freed when all references to the outputs are gone.
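As a related check, you can see that backward itself releases the graph's intermediate buffers by default: a second backward over the same graph raises a RuntimeError unless you pass retain_graph=True. A minimal sketch on plain CPU tensors (no Variable wrapper needed on recent PyTorch):

```python
import torch

x = torch.randn(3, requires_grad=True)
out = (x * x).sum()

out.backward()          # by default, frees the graph's saved buffers
try:
    out.backward()      # second backward over the same graph now fails
    freed = False
except RuntimeError:
    freed = True        # "Trying to backward through the graph a second time"
```
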


In the above code, for samples which do not contribute to the loss, there is no reference to the output after the if statement. Does that mean the sub-graph for such a sample is freed immediately?

After the if, you are still holding onto that out until the name is rebound by the next forward, so there is still a reference. Even once there is no reference, exactly when the graph gets deleted depends on Python's garbage collection.
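Following that advice, one option is to drop the reference explicitly with del as soon as you know a sample is skipped, instead of waiting for the next forward to rebind out. A minimal sketch with a hypothetical model and a stand-in for some_condition (the Linear model, MSELoss, and the out.item() > 0 check below are all illustrative assumptions, not from the original post):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)          # hypothetical model
criterion = torch.nn.MSELoss()         # hypothetical criterion
data = torch.randn(8, 4)
target = torch.randn(8, 1)

total_loss = 0.0
count = 0
for i in range(len(target)):
    out = model(data[i:i + 1])
    if out.item() > 0:                 # stand-in for some_condition(out)
        total_loss = total_loss + criterion(out, target[i:i + 1])
        count += 1
    else:
        # Release the only reference to this sample's output; CPython's
        # reference counting can then reclaim its sub-graph right away,
        # without waiting for the next forward pass to rebind `out`.
        del out

if count > 0:
    (total_loss / count).backward()    # frees the retained graphs as well
```
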