How to call the backward function of a custom module

I have a question about how to define a custom module in PyTorch. In my case, I need to define my own forward and backward functions. Some example code shows that an autograd.Function subclass should have forward and backward parts, and an nn.Module subclass should have __init__ and forward parts, but when I do it this way, my custom backward function is never called. Can you give me some advice on this?
Thank you!

In order to call a custom backward pass from your custom nn.Module, you should define your own autograd.Function and incorporate it into your nn.Module. Here’s a minimal dummy example:

import torch
import torch.autograd as autograd
import torch.nn as nn

class MyFun(torch.autograd.Function):
    def forward(self, inp):
        return inp

    def backward(self, grad_out):
        # the gradient w.r.t. the input is just the incoming gradient
        grad_input = grad_out.clone()
        print('Custom backward called!')
        return grad_input

class MyMod(nn.Module):
    def forward(self, x):
        # delegate to the custom Function, which supplies the backward
        return MyFun()(x)

mod = MyMod()

y = autograd.Variable(torch.randn(1), requires_grad=True)
z = mod(y)
z.backward()

Running this script outputs:

$ python dummy.py
Custom backward called!

Find more information in the following tutorial:
http://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-defining-new-autograd-functions
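
Note that in more recent PyTorch releases (0.4 and later) the Function API changed: forward and backward are static methods that take a ctx object, and the Function is invoked via .apply rather than by instantiating it. A minimal sketch of the same example in that style (assuming a current PyTorch install):

import torch
import torch.nn as nn

class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # clone so the output is a new tensor and this node is kept
        # in the graph (see the note about identity forwards below)
        return inp.clone()

    @staticmethod
    def backward(ctx, grad_out):
        print('Custom backward called!')
        return grad_out.clone()

class MyMod(nn.Module):
    def forward(self, x):
        # static Functions are applied via .apply, not instantiated
        return MyFun.apply(x)

mod = MyMod()
y = torch.randn(1, requires_grad=True)  # Variable is no longer needed
z = mod(y)
z.backward()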


Yes, it works, thank you :grinning:

Hi jytug: If, in addition, I want to pass some other information backward layer by layer besides the output gradient, what should I do? There is no need to add anything in the forward pass. I tried to use the following code, but it doesn’t work. Thanks a lot!

import torch
import torch.autograd as autograd
import torch.nn as nn

class MyFun(torch.autograd.Function):
    def forward(self, inp):
        return inp

    def backward(self, grad_out, P):
        grad_input = grad_out.clone()
        print('Custom backward called!')
        return grad_input, P-1

class MyMod(nn.Module):
    def forward(self, x):
        return MyFun()(x)

mod = nn.Sequential(MyMod(), MyMod())

y = autograd.Variable(torch.randn(1), requires_grad=True)
z = mod(y)
P = torch.ones((1,1))
z.backward(P)
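
One way to make this work (a sketch, not the original poster's code, written with the newer static-method API): backward receives exactly one gradient per forward output and must return exactly one gradient per forward input, so the extra tensor has to travel through forward as well, even if forward leaves it untouched:

import torch

class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, P):
        # one forward input per backward output: P rides along with the data
        return inp.clone(), P.clone()

    @staticmethod
    def backward(ctx, grad_out, grad_P):
        print('Custom backward called!')
        # grad_P carries the extra information from the layer above;
        # returning grad_P - 1 passes a modified version further down
        return grad_out.clone(), grad_P - 1

y = torch.randn(1, requires_grad=True)
P0 = torch.ones(1, requires_grad=True)

# chain two layers by hand so P flows through both
z1, P1 = MyFun.apply(y, P0)
z2, P2 = MyFun.apply(z1, P1)

# supply an initial gradient for every output of the last layer
torch.autograd.backward([z2, P2], [torch.ones(1), torch.ones(1)])
print(y.grad, P0.grad)  # P0.grad is 1 - 1 - 1 = -1 after the two layers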

In my environment with PyTorch 0.2.0, it seems that PyTorch is smart enough to skip calling the backward function when the forward is the identity function, as in the example above.
However, the backward function is called when the forward returns inp + 0.0 :smile:
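
For reference, here is that tweak applied to the first example (keeping the old-style API from this thread); per the observation above, returning a new tensor instead of the input itself makes PyTorch 0.2.0 call the custom backward:

class MyFun(torch.autograd.Function):
    def forward(self, inp):
        # inp + 0.0 creates a new tensor, so autograd keeps this node
        return inp + 0.0

    def backward(self, grad_out):
        print('Custom backward called!')
        return grad_out.clone()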