Help: Accumulate unreduced loss over several batches and call .backward()


I have a multi-label problem where I am using a masked BCE loss. For each sample, which loss values are masked depends on that sample's ground truth, so the mask changes from example to example.

There are 18 labels, and let's assume batch_size=4. After applying the mask, the loss of a batch looks something like this:

        [[0.7634, 0.4821, 0.5363, 1.0051, 0.7906, 1.2232, 0.5324, 0.6181, 0.5487,
         0.0000, 0.6234, 0.0000, 0.9262, 0.6945, 0.5609, 0.2120, 0.5893, 0.5999],
        [0.7202, 0.5397, 0.6313, 0.8765, 0.0000, 1.4007, 0.0000, 0.0000, 0.4348,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2584, 0.0000, 0.6737],
        [0.7303, 0.4027, 1.0703, 0.8589, 0.0000, 1.0411, 0.0000, 0.0000, 0.9066,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2357, 0.0000, 0.6243],
        [0.8076, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]
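For readers reproducing this, a loss matrix like the one above can be produced roughly as follows. This is a sketch under assumptions: the model is assumed to output raw logits (hence binary_cross_entropy_with_logits rather than plain BCE), and the random mask is a hypothetical stand-in for the real, ground-truth-dependent one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, num_labels = 4, 18

# Hypothetical model outputs (raw logits) and multi-hot targets.
logits = torch.randn(batch_size, num_labels, requires_grad=True)
targets = torch.randint(0, 2, (batch_size, num_labels)).float()

# Hypothetical mask: 1 = keep this label's loss for this sample, 0 = mask it out.
# The first label is never masked, as in the post.
mask = torch.randint(0, 2, (batch_size, num_labels)).float()
mask[:, 0] = 1.0

# reduction="none" keeps one loss value per (sample, label) pair.
loss_raw = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
loss_masked = loss_raw * mask  # masked entries become exactly 0
print(loss_masked.shape)  # torch.Size([4, 18])
```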

The first label is never masked, so in every batch it always has 4 non-zero losses to average over before a backward pass. The problem is that other labels may have fewer than 4 non-zero losses in a single batch, which is roughly equivalent to reducing the batch size for those labels. To compensate, I intended to accumulate each label's losses until there are 4 non-zero values, and only then do a backward pass for that label and call optimizer.step().
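Counting the non-zero entries per column gives each label's effective batch size. A small sketch with a made-up 2 x 3 matrix (hypothetical values); when the mask itself is at hand, mask.sum(dim=0) is the safer count, since an unmasked loss could in principle be exactly 0.

```python
import torch

# Hypothetical masked loss matrix (2 samples x 3 labels); 0.0 marks a masked entry.
loss_masked = torch.tensor([[0.7, 0.5, 0.0],
                            [0.9, 0.0, 0.0]])

# Number of non-masked losses per label (column): the per-label effective batch size.
counts = (loss_masked != 0).sum(dim=0)
print(counts.tolist())  # [2, 1, 0]

# Per-label mean over non-masked entries only; clamp avoids dividing by zero
# for a label that is fully masked in this batch.
per_label_mean = loss_masked.sum(dim=0) / counts.clamp(min=1)
print(per_label_mean.tolist())
```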

I am having trouble implementing this. I tried concatenating the loss matrix of every new batch onto an accumulated tensor, and then backpropagating accumulated_loss[:, i] once the number of non-zero elements in column i reaches 4, but I get this RuntimeError:

Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
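A minimal reproduction of the error, assuming (as in the post) that the accumulated columns come from the same forward pass. The toy tensors w and x are hypothetical; each output column plays the role of one label's loss.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([[3.0, 4.0]])

# One forward pass; both columns share a single autograd graph.
out = torch.sigmoid(x * w)  # shape (1, 2)

out[:, 0].sum().backward()  # backward for "label 0" frees the graph's saved tensors

raised = False
try:
    out[:, 1].sum().backward()  # "label 1" walks the same, already freed graph
except RuntimeError as err:
    raised = "backward through the graph a second time" in str(err)
print(raised)  # True
```

Each backward call frees the saved tensors of the graph it walks, so a later backward on another column (or on a concatenated tensor that still points into an earlier batch's graph) fails.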

Any help would be really appreciated.

I think you’re calling .backward() before cat.


Thanks for the reply. I am not using backward before cat.
I was trying something like this:

placeholder = {x: torch.empty((0, 1), device=device) for x in ["l1", "l2", "l3", (...... etc), "l18"]}
for batch_idx, (inputs, labels, filenames) in loop:
    inputs = ...
    labels = ...

    # zero the parameter gradients
    # forward
    with torch.set_grad_enabled(phase == 'train'):
        outputs = model(inputs)
        # loss_raw: unreduced BCE loss, shape (batch_size, 18)
        loss_masked = loss_raw * mask

    # for each example in the batch
    for k in range(loss_masked.size(0)):
        # add the loss of each label to its placeholder
        placeholder ...... (etc) ............

    for l in ["l1", "l2", "l3", .... (etc) .... "l18"]:
        # if the placeholder holds 4 non-zero losses, do a backward pass
        # and clear the placeholder
        if placeholder[l][placeholder[l].sum(1) != 0].size(0) == 4:
            print("------    backprop of branch", l, "    ------")

# the loss of label 0 is always != 0


PS: I know the unsqueezing is unnecessary in this small example, but in reality the 18 labels are split into four groups that I want to cluster together.

Any ideas?

I don’t know much about it, but the docs say:
retain_graph (bool, optional): If False, the graph used to compute the grads will be freed.
So if you want to backpropagate through the model multiple times, you should set retain_graph=True, something like loss.backward(retain_graph=True), and then free the graph manually (I don’t know how).

I agree that loss.backward(retain_graph=True) plus manually freeing the graph should work, but I don’t know how to do that either.
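For reference, a sketch of the retain_graph pattern on the same kind of toy graph (w and x are hypothetical): pass retain_graph=True on every backward call except the last; the final call, without it, frees the saved tensors. The graph is also released once no remaining tensor references it, which is what del out does here. Keeping graphs alive across several batches holds all of their saved activations in memory, so this can get expensive.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([[3.0, 4.0]])
out = torch.sigmoid(x * w)  # one forward pass, two "label" columns

# Keep the saved tensors alive for every backward pass except the last one.
out[:, 0].sum().backward(retain_graph=True)
out[:, 1].sum().backward()  # no retain_graph: the saved tensors are freed here

# Gradients from both calls were summed into w.grad.
print(w.grad)

# Dropping the last reference to the graph's output lets the graph be released.
del out
```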

Thanks for your help!

I think this approach is way too complicated. Correct me if I’m wrong, but your problem is an unbalanced loss: some labels will have more non-zero losses than others. IMO this is similar to an unbalanced dataset. Why not just use a weighted loss, with each label’s weight = (number of non-zero losses for that label) / (total number of non-zero losses)?
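A minimal sketch of the suggested weighting, with a random mask and random losses standing in for real data (all values hypothetical). One observation: with weights defined as (non-zero losses per label) / (total non-zero losses), the weighted sum of the per-label means works out to the plain mean over all non-masked entries in the batch, which the last lines check.

```python
import torch

torch.manual_seed(0)

# Hypothetical 0/1 mask (batch of 4, 18 labels) and masked loss matrix.
mask = (torch.rand(4, 18) > 0.5).float()
mask[:, 0] = 1.0                    # the first label is never masked
loss_masked = torch.rand(4, 18) * mask

counts = mask.sum(dim=0)            # non-masked losses per label
total = counts.sum()                # non-masked losses in the whole batch

# Per-label mean over non-masked entries (clamp guards fully masked labels).
per_label_mean = loss_masked.sum(dim=0) / counts.clamp(min=1)

# Weight each label by (its non-masked count / total non-masked count).
weights = counts / total
weighted_loss = (weights * per_label_mean).sum()

# The weighting reduces to the plain mean over all non-masked entries.
plain_mean = loss_masked.sum() / total
print(torch.allclose(weighted_loss, plain_mean))  # True
```

In training, weighted_loss would be built from a loss_masked that requires grad and backpropagated once per batch with a single .backward() call, so no graph has to be retained across batches.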


Makes total sense and it is soooo much more obvious than the complicated solution I was going for. Sometimes all one needs is a fresh pair of eyes. Will let you know if it works as expected. Cheers