Disable gradient for loss based on result

I have a module that outputs 2 losses, l_1 and l_2.

I only want the gradients from l_2 to flow if l_2 is negative, but I still need l_2 for later computations. How can I achieve this?

You could just add a condition before calling backward:

if l_2 < 0:
    l_2.backward()  # backpropagate l_2 only when it is negative
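Here is a minimal runnable sketch of the same idea. It assumes both losses are single-element tensors and uses a made-up parameter `w` and made-up losses purely for illustration. `l_1` always contributes, `l_2` contributes only when it is negative, and `l_2` itself stays available for later computations:

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)
l_1 = (w ** 2).sum()   # stand-in for the first loss
l_2 = w.sum()          # stand-in for the second loss (here -1.0)

loss = l_1
if l_2 < 0:            # valid because l_2 holds a single value
    loss = loss + l_2  # only then does l_2 contribute gradients
loss.backward()

print(w.grad)
```

Since `l_2` is negative here, `w.grad` is the gradient of `l_1 + l_2`, i.e. `2 * w + 1`.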

Would that work in your use case or is it more complicated?

Thanks for the quick answer. Is there another way to implement this self-contained in the nn.Module subclass? Doing it outside would get very, very complicated programming-wise, because I did not start the project with this option in mind and combine the losses early. I have a lot of these modules and a lot of these “l_2” losses. I want to disable the gradients for positive values in some modules and not in others, based on hyperparameters.

Could you give a small example of how these l_2 losses are calculated in the modules and where/how you would like to disable them?
Maybe detaching the critical tensors based on the condition would work.

Okay, I have modules that compute the loss for the main task and the absolute log-determinant of the transformation (for example, the sum of the logs of the absolute diagonal elements when multiplying by a triangular matrix; for non-linear transformations the whole thing is more difficult). It is similar to normalizing flows, if you know this approach.
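As a quick check of the triangular case: the determinant of a triangular matrix is the product of its diagonal entries, so log|det W| is the sum of log|W_ii|. The sketch below compares that shortcut against `torch.linalg.slogdet` on an arbitrary lower-triangular matrix (the values are chosen only for illustration):

```python
import torch

# Arbitrary lower-triangular matrix (illustrative values only).
W = torch.tensor([[2.0, 0.0, 0.0],
                  [0.5, -3.0, 0.0],
                  [1.0, 4.0, 0.25]])

# det(W) is the product of the diagonal, so log|det(W)| is the
# sum of the logs of the absolute diagonal entries.
logdet_diag = torch.log(torch.abs(torch.diagonal(W))).sum()

# Reference value from the general-purpose routine.
sign, logabsdet = torch.linalg.slogdet(W)

print(logdet_diag.item(), logabsdet.item(), sign.item())
```

Here det(W) = 2 · (-3) · 0.25 = -1.5, so both values equal log 1.5 and the sign is -1.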

In my earlier layers, I don’t want the weights to optimize the absolute log-determinant because of numerical problems; I only want to avoid “low” log-determinants (< 0). But I don’t know for how many layers to disable this, or whether it would really solve my problem, so I have to experiment a lot.

All the log-determinants of the modules get combined via a sum and added to the overall loss. Since I did not start the project with this approach in mind, I combine the log-determinants along the way for better code organization.
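A minimal sketch of that pattern, assuming a hypothetical `TriangularLinear` flow layer (not from the thread) that returns its output together with its log-determinant. The loop sums the log-determinants across layers; the main-task loss is a placeholder, and whether the sum enters the loss with a plus or minus sign depends on your objective:

```python
import torch
import torch.nn as nn


class TriangularLinear(nn.Module):
    """Hypothetical flow layer: y = x @ W.T with W lower-triangular."""

    def __init__(self, dim):
        super().__init__()
        # Start from the identity, so the initial log-determinant is 0.
        self.weight = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        W = torch.tril(self.weight)  # keep only the lower triangle
        y = x @ W.T
        # log|det W| = sum of log|diagonal entries| (same for every sample)
        logdet = torch.log(torch.abs(torch.diagonal(W))).sum()
        return y, logdet


layers = nn.ModuleList([TriangularLinear(3) for _ in range(4)])
x = torch.randn(8, 3)

total_logdet = torch.zeros(())
for layer in layers:
    x, logdet = layer(x)
    total_logdet = total_logdet + logdet  # combine as we go

main_loss = x.pow(2).mean()      # placeholder for the main-task loss
loss = main_loss + total_logdet  # sign depends on the objective
loss.backward()
```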

I essentially have something like this:

import torch.nn as nn

class CustomClass(nn.Module):
    def forward(self, input):
        # (do some calculations that produce l_1 and logdets)
        return l_1, logdets

l_1 gets used immediately in the next layer, and I want to disable the gradient from logdets if it’s greater than zero.

If you would like to stop the gradient from flowing back from logdets when it’s > 0, this code should probably work:

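def forward(self, input):
    # ... compute l_1 and logdets as before ...
    if logdets > 0:  # only valid if logdets holds a single value
        logdets = logdets.detach()  # keep the value, block the gradient
    return l_1, logdets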
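To make this switchable per module, one option is to store the decision as a constructor argument. The sketch below assumes a hypothetical `ScaleLayer` (y = x · exp(s), so log|det| = sum(s)) and a hypothetical `detach_positive` hyperparameter; neither comes from the thread. It still relies on `logdets` being a single value, because `if` needs one boolean:

```python
import torch
import torch.nn as nn


class ScaleLayer(nn.Module):
    """Hypothetical flow layer y = x * exp(s); log|det| = sum(s)."""

    def __init__(self, dim, detach_positive=False):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.detach_positive = detach_positive  # per-module hyperparameter

    def forward(self, x):
        y = x * torch.exp(self.log_scale)
        logdets = self.log_scale.sum()  # single-element tensor
        if self.detach_positive and logdets > 0:
            logdets = logdets.detach()  # same value, no gradient flows back
        return y, logdets


layer = ScaleLayer(2, detach_positive=True)
with torch.no_grad():
    layer.log_scale.fill_(0.5)  # makes logdets = 1.0 > 0

y, logdets = layer(torch.ones(1, 2))
(y.sum() + logdets).backward()
```

With the flag set, `log_scale.grad` contains only the contribution of `y.sum()` (exp(0.5) per entry); without it, each entry would additionally receive 1 from `logdets`.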

I think it should, if I understood detach correctly. That’s brilliant!


@ptrblck I revisited my problem. Unfortunately, your solution does not work because logdets is a tensor, and I want to let the gradient flow depending on an element-wise condition. I thought about the custom torch.autograd.Function interface, but I am not sure whether it would work, since it provides gradients with respect to the output, whereas I would need to manipulate the gradients flowing into the logdets computation.
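One way to apply the condition element-wise is `torch.where`, which picks a detached copy for positive entries and the original (attached) tensor elsewhere. The forward values are unchanged, but only the non-positive entries pass gradients back. A custom `torch.autograd.Function` also works: whatever its `backward` returns is exactly the gradient that flows into the operations that produced its input, so an identity function that masks `grad_output` does affect the logdets computation. The sketch below shows both; the helper names `detach_positive` and `MaskPositiveGrad`, and the tensor `s`, are hypothetical:

```python
import torch


def detach_positive(logdets):
    # Positive entries are replaced by detached copies (same value, no
    # gradient); all other entries stay attached to the graph.
    return torch.where(logdets > 0, logdets.detach(), logdets)


class MaskPositiveGrad(torch.autograd.Function):
    """Identity in forward; zeroes the gradient where the input was positive."""

    @staticmethod
    def forward(ctx, logdets):
        ctx.save_for_backward(logdets)
        return logdets.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (logdets,) = ctx.saved_tensors
        # The returned gradient flows back into the logdets computation.
        return grad_output * (logdets <= 0).to(grad_output.dtype)


s = torch.tensor([0.5, -1.0, 2.0, -0.25], requires_grad=True)

out1 = detach_positive(s * 3.0)
out1.sum().backward()
grad_where = s.grad.clone()

s.grad = None
out2 = MaskPositiveGrad.apply(s * 3.0)
out2.sum().backward()
grad_function = s.grad.clone()

print(out1.detach(), grad_where, grad_function)
```

Both variants give a gradient of 0 for the positive entries and 3 for the others. `torch.clamp(logdets, max=0)` would also block the gradient for positive entries, but it replaces their values with 0, which does not fit if the original values are needed later.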