Accessing the loss from any module at backward time

Hello 🙂

I’m trying to implement a function that operates on the raw model loss at backward time, because the gradient it would otherwise receive is badly defined. One naive approach is to store the loss in a global variable and read it from there, but relying on global variables feels a bit fragile.

Is there a mechanism that lets me get the model loss from the ctx in a few lines (i.e. without modifying the whole model)?

I’m not sure I understand the question completely, but if you would like to grab the gradient of specific parameters, you could register a hook on them (e.g. with register_hook) and manipulate or print out the gradient inside that hook.
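A minimal sketch of the hook approach, assuming a small nn.Linear as a stand-in model and a hypothetical helper named `inspect_grad`. `Tensor.register_hook` calls the function with the gradient during backward; returning a tensor from the hook replaces the gradient, while returning the input leaves it unchanged.

```python
import torch
import torch.nn as nn

# Small stand-in model; any parameter of any nn.Module works the same way
model = nn.Linear(3, 1)
seen = []

def inspect_grad(grad):
    # Called with d(loss)/d(weight) while backward runs
    seen.append(grad.clone())
    # Returning a tensor would replace the gradient; here it is passed through
    return grad

handle = model.weight.register_hook(inspect_grad)

x = torch.randn(4, 3)
loss = model(x).pow(2).mean()
loss.backward()

handle.remove()  # detach the hook once it is no longer needed
print(seen[0].shape)  # torch.Size([1, 3])
```

Note that the hook only sees the gradient flowing into that one tensor, not the loss that started the backward pass.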

Would this work, or could you explain which loss you would like to see?

Thanks for the reply. Let me explain a bit more: say I have a function such as

import torch

class MyRound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)
    @staticmethod
    def backward(ctx, grad):
        # Pass the incoming gradient through unchanged
        return grad.clone()
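A short usage sketch of a pass-through rounding Function like the one above (the input values are arbitrary). Custom Functions are invoked through `.apply`, and with this backward every element of the input receives a gradient of exactly 1, whereas plain torch.round gives a zero gradient.

```python
import torch

class MyRound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: hand the upstream gradient back unchanged
        return grad.clone()

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
y = MyRound.apply(x)  # values: 0., 2., -2.
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])
```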

In my use case, the grad received in backward is not well defined. Rather than operating on this badly defined grad, I could look at the model loss (e.g. cross entropy) and redefine a gradient that is useful. Can I access the cross-entropy loss (or whatever “root” loss the backpropagation started from) from within backward, e.g. through the ctx object? The result would look like:

    def backward(ctx, grad):
        # hypothetical attribute holding the root loss (this is what I'm asking for)
        good_grad = ctx.root_loss
        # ...some operations to turn it into a useful gradient
        return good_grad
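As far as I know, the ctx object does not expose the loss that started the backward pass. One workaround that avoids a module-level global is to pass the Function a small mutable holder at forward time and fill it with the loss just before calling backward(). The sketch below assumes that pattern; `LossHolder`, `RoundWithLoss` and the "scale by the loss" rule are hypothetical placeholders, not PyTorch APIs, and the toy weights/targets are arbitrary.

```python
import torch
import torch.nn.functional as F

class LossHolder:
    """Hypothetical container shared by the training loop and backward."""
    def __init__(self):
        self.loss = None

class RoundWithLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, holder):
        # Keep a reference to the holder; the loop fills in .loss later
        ctx.holder = holder
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad):
        root_loss = ctx.holder.loss  # set before loss.backward() was called
        # Placeholder rule (assumption): scale the pass-through gradient by
        # the loss value. Replace with the gradient you actually need.
        good_grad = grad * root_loss
        # One gradient per forward input; the holder is not a tensor -> None
        return good_grad, None

holder = LossHolder()
w = torch.tensor([[0.4, 1.6], [-0.7, 2.2]], requires_grad=True)
x = torch.tensor([[1.0, 2.0]])
target = torch.tensor([1])

logits = x @ RoundWithLoss.apply(w, holder)
loss = F.cross_entropy(logits, target)

holder.loss = loss.detach()  # make the root loss visible to backward
loss.backward()
print(w.grad.shape)  # torch.Size([2, 2])
```

The holder is still shared state, just scoped to the objects that need it; it has to be refreshed with the new loss on every training step, before each backward() call.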