Hello
I’m trying to implement a function that operates on the raw model loss at backward time, because the gradients there are badly defined. One naive approach is to store the loss in a global variable and read it inside backward, but juggling global variables feels a bit hacky.
So, is there a mechanism that lets me get the model loss from the ctx
in a few lines? (i.e. without modifying the whole model)
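To illustrate, the naive global-variable version I’d like to avoid would look roughly like this (LOSS_HOLDER and MyFunc are just illustrative names):

import torch

# Hypothetical module-level holder that the training loop fills in before loss.backward()
LOSS_HOLDER = {"loss": None}

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad):
        loss = LOSS_HOLDER["loss"]  # read the globally stashed loss
        # ... derive a better gradient from the loss here ...
        return grad.clone()

# In the training loop, roughly:
#     loss = criterion(model(x), y)
#     LOSS_HOLDER["loss"] = loss.detach()
#     loss.backward()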
I’m not sure I understand the question completely, but if you would like to grab the gradient at specific parameters, you could use a hook, e.g.:
model.layer.param.register_hook(my_hook)
and manipulate or print out the gradient.
Would this work, or could you explain what loss you would like to see?
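A minimal sketch of that idea (the small two-layer model and hook below are just placeholders):

import torch
import torch.nn as nn

def my_hook(grad):
    # Called with the gradient w.r.t. the parameter during backward;
    # return a (possibly modified) gradient, or None to leave it unchanged.
    print("grad norm:", grad.norm().item())
    return grad

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model[0].weight.register_hook(my_hook)

out = model(torch.randn(3, 4))
out.sum().backward()  # my_hook fires when the gradient for model[0].weight is computed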
Thanks for the reply, I’ll explain a bit more: let’s say I have a function such as:
import torch

class MyRound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Round in the forward pass; round() has zero gradient almost everywhere.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: pass the incoming gradient back unchanged.
        return grad.clone()
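Used like this, it simply passes the upstream gradient straight through the rounding:

x = torch.randn(3, requires_grad=True)
y = MyRound.apply(x)
y.sum().backward()
print(x.grad)  # all ones: the gradient of sum() passed straight through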
In my use case the grad passed to backward is not well defined. Rather than operating on this badly defined grad, I would like to look at the model loss (e.g. cross-entropy) and redefine a gradient that is actually useful. Can I access the cross-entropy loss (or whatever “root” loss the backpropagation started from) from within backward, e.g. through the ctx object? The result would look like:
@staticmethod
def backward(ctx, grad):
    good_grad = ctx.root_loss  # hypothetical: the loss that backward() was called on
    # ... and some operations to build the correct grad
    return good_grad
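(The closest mechanism I’m aware of is ctx.save_for_backward, but that only lets me stash tensors that are already available inside forward, e.g. something like the hypothetical MyRound2 below, so it still doesn’t give me the final loss, which is only computed further downstream.)

class MyRound2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, aux):
        # aux (e.g. the targets) has to be passed in explicitly to be visible here
        ctx.save_for_backward(aux)
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad):
        aux, = ctx.saved_tensors
        # ... still no handle on the final cross-entropy loss here ...
        return grad.clone(), None  # one gradient per forward input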