Customize forward and backward by subclassing autograd.Function

Hi, from what I know, I can define forward and backward functions like this:

```python
def forward(self, input): ...
def backward(self, grad_output): ...
```

My question is: can I change the interface of the two functions like this?

```python
def forward(self, input, others): ...  # need other inputs
def backward(self, others): ...  # no need of grad_output since this is a loss layer
```

Thanks!


There are two parts to this. First, it is OK to have several input arguments to forward.
This corresponds to returning a tuple of several input gradients in backward, possibly None if you don't want to backprop into some of the inputs. (Trailing Nones are optional in old-style autograd, but they become required in new-style autograd (master / PyTorch >= 0.2).)
However, grad_output should always be there! To autograd, loss functions are not terribly special, and if you add two losses you would still want that to work. (In the end, scalar variables are special only in that you can call .backward() without arguments, implicitly using 1 as the starting gradient.) If your result is a (scalar) loss, grad_output will be a scalar variable.
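To make these two points concrete, here is a minimal sketch using the new-style API described in the next paragraph (static methods with `ctx`), written against current PyTorch, where saved tensors are read back with `ctx.saved_tensors`. `SquaredErrorLoss` and `seen_grad_outputs` are hypothetical names for illustration. The function takes two inputs, returns one gradient per input (`None` for the target), and records `grad_output` so you can see that it is 1 when `.backward()` is called on the scalar loss.

```python
import torch
from torch.autograd import Function

# Hypothetical list, used only to inspect what backward receives.
seen_grad_outputs = []


class SquaredErrorLoss(Function):
    """Hypothetical loss: sum((input - target) ** 2), no gradient for target."""

    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        return ((input - target) ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        seen_grad_outputs.append(grad_output.detach().clone())
        # One return value per forward input: a gradient for `input`,
        # and None for `target`, which we do not backprop into.
        return grad_output * 2 * (input - target), None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([0.0, 2.0, 5.0])

loss = SquaredErrorLoss.apply(x, target)
loss.backward()  # scalar loss: autograd starts from grad_output = 1
print(seen_grad_outputs[-1].item())  # 1.0
print(x.grad.tolist())               # [2.0, 0.0, -4.0]
```

Adding two such losses and calling `.backward()` on the sum also works, because each call's `backward` simply receives its own `grad_output`.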

As mentioned above, autograd has changed to make taking derivatives of derivatives work. While the old style, currently used in the tutorial, still works, the new way is to define forward and backward as static methods and to store things in a context (ctx instead of self as the first argument). If you do this, .backward (but not .forward) operates on Variables instead of Tensors (use ctx.saved_variables and it will take care of that), and you need to return None for all non-differentiable arguments. These new-style Functions are then applied with YourFunction.apply(*inputs).
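A minimal sketch of the new style, written against current PyTorch, where Variables and Tensors have since been merged and saved values are read back with `ctx.saved_tensors` (which replaced `saved_variables`). `ScaledExp` is a hypothetical example: its second argument is a plain Python float, so `backward` returns `None` in that slot.

```python
import torch
from torch.autograd import Function


class ScaledExp(Function):
    """Hypothetical example: y = scale * exp(x), where scale is a Python float."""

    @staticmethod
    def forward(ctx, x, scale):
        y = scale * torch.exp(x)
        # dy/dx = scale * exp(x) = y, so the output is all backward needs.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # One entry per forward argument; `scale` is not differentiable -> None.
        return grad_output * y, None


x = torch.tensor([0.0, 1.0], requires_grad=True)
y = ScaledExp.apply(x, 2.0)  # call via .apply, not by instantiating the class
y.sum().backward()
print(x.grad)  # equals 2 * exp(x)
```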
If you want to look at small examples, I can offer implicit functions and Cholesky decomposition, though they are not really commented on how they use the autograd mechanics. Alternatively, look at the PyTorch source; many functions are defined in the files of .

Best regards


Thanks a lot, Thomas!
1. So if I use the old way:

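```python
def forward(self, input, others):
    # ... compute the loss from input and others ...
    return loss

def backward(self, grad_output):
    # one gradient per forward input; None for `others`
    return grad_input, None
```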

So backward must receive a single parameter, grad_output (and no others), even if I do not use it, right? And if I print out the grad_output value, it will be 1 if this is the final scalar loss layer, right?

2. When do we have to write the backward function ourselves? Is your implicit function an example of this? I am also confused about whether the automatic backward mechanism can handle an example like the following toy one:

```python
import torch.nn as nn

class Model(nn.Module):
    def forward(self, input, x, y):  # input is of size N*D
        tmp = input[x, y]
        loss = tmp - 1
        return loss
```

In the real case, the loss is calculated from different elements of input under different conditions, so the calculation is different each time. Can I just leave it to the automatic backward to handle?


For 1, yes, that is the general idea.

It is a good idea to multiply the input gradient you return by grad_output[0] to be consistent here, even if you expect it to usually be one. You might have a use case where you want to weight your loss against other losses, and then you'll be happy if it just works.
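A small sketch of why this matters, assuming current PyTorch: two hypothetical versions of the same sum loss, one that ignores `grad_output` and one that multiplies by it. When the loss is weighted by 0.5, only the second returns the correct gradient. (In current PyTorch the loss is a 0-dim tensor, so `grad_output` is a 0-dim tensor too and is used directly rather than indexed with `[0]`.)

```python
import torch
from torch.autograd import Function


class SumLossIgnoringGrad(Function):
    """Hypothetical buggy version: ignores grad_output."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sum()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return torch.ones_like(input)  # only correct when grad_output == 1


class SumLoss(Function):
    """Hypothetical correct version: scales by grad_output (chain rule)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sum()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return grad_output * torch.ones_like(input)


def grad_of_weighted_loss(fn, weight=0.5):
    """Hypothetical helper: gradient of weight * fn(x) at x = ones(3)."""
    x = torch.ones(3, requires_grad=True)
    (weight * fn.apply(x)).backward()
    return x.grad.tolist()


print(grad_of_weighted_loss(SumLossIgnoringGrad))  # [1.0, 1.0, 1.0] (wrong)
print(grad_of_weighted_loss(SumLoss))              # [0.5, 0.5, 0.5] (correct)
```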

A final piece of advice about implementing Functions that I forgot: the function autograd.gradcheck tests your implementation of the gradient against the numerical derivative (see e.g. code cell 7 in the implicit function example). It's always good to use it, but you might get false positives (spurious failures) due to numerical errors or extreme nonlinearity.
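A minimal `gradcheck` sketch, assuming current PyTorch. `Cube` and `CubeWrongBackward` are hypothetical Functions; `gradcheck` compares each `backward` against finite differences. It wants double-precision inputs with `requires_grad=True`, since single precision tends to produce exactly the spurious failures mentioned above.

```python
import torch
from torch.autograd import Function, gradcheck


class Cube(Function):
    """Hypothetical example: y = x ** 3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2


class CubeWrongBackward(Function):
    """Same forward, but backward deliberately drops the factor 3."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x ** 2


# Double precision keeps the finite-difference estimate accurate.
inputs = (torch.tensor([0.5, -1.0, 2.0, 1.5], dtype=torch.double, requires_grad=True),)

ok = gradcheck(Cube.apply, inputs, eps=1e-6, atol=1e-4)
# raise_exception=False makes gradcheck return False instead of raising.
bad = gradcheck(CubeWrongBackward.apply, inputs, eps=1e-6, atol=1e-4,
                raise_exception=False)
print(ok, bad)  # True False
```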

For 2:

You only write the backward function if you can't express what you want to compute with the operations that are already there. So in your example you did the right thing by just writing an nn.Module and letting the backward be calculated step by step automatically.
In practice, I rarely find much missing.
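A quick check of this, assuming current PyTorch: the `Model` from the question, plus a hypothetical `ConditionalLoss` that uses different elements depending on the data. Autograd records whichever operations actually ran on each call, so no hand-written backward is needed even when the computation changes from call to call.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, input, x, y):  # input is of size N*D
        return input[x, y] - 1


class ConditionalLoss(nn.Module):
    """Hypothetical: picks a different element depending on the input values."""

    def forward(self, input):
        if input[0, 0] > 0:
            return input[0, 1] ** 2
        return input[1, 0] * 3


inp = torch.zeros(2, 3, requires_grad=True)
Model()(inp, 1, 2).backward()
print(inp.grad)  # 1 at position [1, 2], 0 elsewhere

inp2 = torch.tensor([[1.0, 2.0, 0.0], [0.0, 0.0, 0.0]], requires_grad=True)
ConditionalLoss()(inp2).backward()
print(inp2.grad)  # 2 * input[0, 1] = 4 at [0, 1], 0 elsewhere
```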

In the two examples, the implicit functions use an iteration to find the value, and that is not good to backpropagate through naïvely. For the Cholesky decomposition, the situation is somewhat similar, although it is mainstream enough (other libraries also have it) that I expect it to eventually be included in PyTorch properly, once the performance issue in the backward is solved.
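To illustrate the first point with a toy case of my own choosing (not the implicit-function example referred to above): `NewtonSqrt` is a hypothetical Function that computes y = sqrt(x) by Newton iteration. Rather than backpropagating through every iteration, `backward` differentiates the defining equation y**2 = x directly: 2y dy = dx, so dy/dx = 1 / (2y). The sketch assumes strictly positive inputs and a fixed, hypothetical iteration count.

```python
import torch
from torch.autograd import Function

NUM_ITERS = 30  # hypothetical fixed iteration count


class NewtonSqrt(Function):
    """Hypothetical: y = sqrt(x) via Newton's method; assumes x > 0."""

    @staticmethod
    def forward(ctx, x):
        y = torch.ones_like(x)  # initial guess
        for _ in range(NUM_ITERS):
            y = 0.5 * (y + x / y)  # Newton step for y**2 - x = 0
        # Only the converged solution is saved, not the iteration history.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Implicit differentiation of y**2 = x gives dy/dx = 1 / (2 * y).
        return grad_output / (2 * y)


x = torch.tensor([4.0, 9.0, 2.0], dtype=torch.double, requires_grad=True)
y = NewtonSqrt.apply(x)
y.sum().backward()
print(y.detach())  # close to sqrt(x)
print(x.grad)      # close to 1 / (2 * sqrt(x))
```

The backward cost is independent of the number of iterations, which is the main reason to write it by hand here.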

Best regards