Hello,
I have a multi-label problem where I am using a masked BCE loss. For each sample, which loss values get masked depends on that sample's ground truth, so the mask is dynamic and changes from example to example.
There are 18 labels, and let's assume batch_size=4. After applying the mask, the loss matrix of a batch looks something like this:
[[0.7634, 0.4821, 0.5363, 1.0051, 0.7906, 1.2232, 0.5324, 0.6181, 0.5487,
0.0000, 0.6234, 0.0000, 0.9262, 0.6945, 0.5609, 0.2120, 0.5893, 0.5999],
[0.7202, 0.5397, 0.6313, 0.8765, 0.0000, 1.4007, 0.0000, 0.0000, 0.4348,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2584, 0.0000, 0.6737],
[0.7303, 0.4027, 1.0703, 0.8589, 0.0000, 1.0411, 0.0000, 0.0000, 0.9066,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2357, 0.0000, 0.6243],
[0.8076, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]
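For reference, the loss matrix above is produced roughly like this (simplified sketch; logits, targets and the way the mask is built here are just placeholders, in my real code the mask comes from each sample's ground truth):

```python
import torch
import torch.nn as nn

# Simplified sketch of how the masked per-element loss matrix is produced.
# logits/targets/mask are stand-ins; the real mask is derived from the
# ground truth of each sample, so it differs per row.
batch_size, num_labels = 4, 18
criterion = nn.BCEWithLogitsLoss(reduction="none")

logits = torch.randn(batch_size, num_labels, requires_grad=True)  # stand-in for model outputs
targets = torch.randint(0, 2, (batch_size, num_labels)).float()
mask = (torch.rand(batch_size, num_labels) > 0.5).float()         # placeholder label-dependent mask
mask[:, 0] = 1.0                                                   # the first label is never masked

loss_matrix = criterion(logits, targets) * mask                    # shape (4, 18), zeros where masked
```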
The first label is never masked, so for a given batch I will always have 4 non-zero losses for it to average over and backpropagate. However, some labels will not have 4 non-zero losses in a single batch, which is effectively the same as reducing the batch size for those labels. Because of this, I intended to accumulate the losses of each label across batches until a label has 4 non-zero values, and only then do a backward pass for that label and call optimizer.step().
I am having trouble implementing this. I basically tried to accumulate the loss matrices of the incoming batches with torch.cat(), and then, once the number of non-zero elements in column i reaches 4, take that column's non-zero losses and do a backward pass, but I get this RuntimeError:
Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
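Here is a simplified version of what I tried (loader, model, optimizer and compute_masked_loss are stand-ins for my actual training code):

```python
import torch

# Simplified version of my accumulation attempt. loader, model, optimizer and
# compute_masked_loss are placeholders for my real training loop; the masked
# loss matrix has shape (4, 18) as shown above.
accumulated_loss = torch.empty(0, 18)

for batch in loader:
    loss_matrix = compute_masked_loss(model, batch)                  # (4, 18) masked per-element losses
    accumulated_loss = torch.cat([accumulated_loss, loss_matrix], dim=0)

    for i in range(18):
        col = accumulated_loss[:, i]
        nonzero = col[col != 0]                                      # only unmasked entries of label i
        if nonzero.numel() >= 4:
            optimizer.zero_grad()
            nonzero[:4].mean().backward()                            # the second time this runs, it raises the RuntimeError
            optimizer.step()
            # (afterwards I also try to drop the consumed entries of column i)
```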
Any help would be really appreciated.