I have a question about how to define a custom module in PyTorch. In my case, I need to define my own forward and backward functions. Some example code shows that an autograd.Function subclass should have forward and backward methods, and an nn.Module subclass should have __init__ and forward methods. But when I do it this way, my custom backward function is never called. Can you give me some advice?
Thank you!
To get custom backward passes in your custom nn.Module, define your own autograd.Functions and call them inside your nn.Module. Here's a minimal dummy example:
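```python
import torch
import torch.nn as nn

class MyFun(torch.autograd.Function):
    # forward and backward are static methods that receive a context object (ctx)
    @staticmethod
    def forward(ctx, inp):
        # Identity in the forward pass
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        # Pass the incoming gradient through unchanged
        grad_input = grad_out.clone()
        print('Custom backward called!')
        return grad_input

class MyMod(nn.Module):
    def forward(self, x):
        # Call the Function through .apply instead of instantiating it
        return MyFun.apply(x)

mod = MyMod()
y = torch.randn(1, requires_grad=True)
z = mod(y)
z.backward()
```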
Running this script outputs:
```
$ python dummy.py
Custom backward called!
```
Find more information in the following tutorial:
http://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-defining-new-autograd-functions
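Once the custom backward is actually being called, a natural next step is to check that it returns the right gradient. The sketch below is an illustration, not code from this thread: it uses a hypothetical Square function that also shows ctx.save_for_backward (keeping forward values around for backward), and checks it against finite differences with torch.autograd.gradcheck. gradcheck expects double-precision inputs.

```python
import torch

class Square(torch.autograd.Function):  # hypothetical example function
    @staticmethod
    def forward(ctx, inp):
        # Save the input so backward can use it
        ctx.save_for_backward(inp)
        return inp * inp

    @staticmethod
    def backward(ctx, grad_out):
        (inp,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the incoming gradient
        return grad_out * 2 * inp

x = torch.randn(4, dtype=torch.double, requires_grad=True)
# Compares the analytical backward with numerical finite differences
ok = torch.autograd.gradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)
```

gradcheck raises an error describing the mismatch if the analytical and numerical gradients disagree, so it is a cheap sanity check for any hand-written backward.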
Yes, it works. Thank you!
Hi jytug: In addition to grad_output, I want to pass some other information backward, layer by layer. Nothing needs to be added in the forward pass. What should I do? I tried the following code, but it doesn't work. Thanks a lot!
```python
import torch
import torch.autograd as autograd
import torch.nn as nn

class MyFun(torch.autograd.Function):
    def forward(self, inp):
        return inp

    def backward(self, grad_out, P):
        grad_input = grad_out.clone()
        print('Custom backward called!')
        return grad_input, P-1

class MyMod(nn.Module):
    def forward(self, x):
        return MyFun()(x)

mod = nn.Sequential(MyMod(), MyMod())
y = autograd.Variable(torch.randn(1), requires_grad=True)
z = mod(y)
P = torch.ones((1,1))
z.backward(P)
```
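One possible workaround, sketched here under assumptions and not an official PyTorch mechanism for side channels: autograd calls backward with exactly one gradient per forward output and expects one gradient per forward input. So backward cannot accept an extra argument, but the extra information can travel as a second output whose "gradient" is the value being passed back. The names PassInfo, PassInfoMod, and info are hypothetical. info must require grad so autograd tracks that channel, and each layer sends info - 1 to the layer before it, mirroring the P-1 in the attempt above.

```python
import torch
import torch.nn as nn

class PassInfo(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, inp, info):
        # Identity on both tensors; `info` only exists to open a backward channel
        return inp.clone(), info.clone()

    @staticmethod
    def backward(ctx, grad_out, grad_info):
        # grad_info holds the information sent back by the next layer
        print('Custom backward called, info =', grad_info.item())
        # Return one "gradient" per forward input: the real one, and the modified info
        return grad_out.clone(), grad_info - 1

class PassInfoMod(nn.Module):  # hypothetical name
    def forward(self, inputs):
        # nn.Sequential passes the previous module's tuple output as a single argument
        x, info = inputs
        return PassInfo.apply(x, info)

mod = nn.Sequential(PassInfoMod(), PassInfoMod())
y = torch.randn(1, requires_grad=True)
info = torch.zeros(1, requires_grad=True)
z, info_out = mod((y, info))
# Seed the ordinary gradient with 1 and the extra information with 1
torch.autograd.backward([z, info_out], [torch.ones_like(z), torch.ones(1)])
print(y.grad, info.grad)
```

Because the extra value is disguised as a gradient, it is subject to autograd's rules: it is summed if a tensor is used in several places, and it ends up in info.grad at the start of the chain (here 1 - 1 - 1 = -1).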
In my environment with PyTorch 0.2.0, PyTorch seems to skip calling the custom backward function when the forward function is the identity (as in the example above). However, the backward function is called when the forward function returns "inp + 0.0".
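Whether backward is skipped for an identity forward depends on the PyTorch version, so a quick check against your installed version may help. The sketch below is an assumption-labeled test harness, not code from the thread: it counts how often each custom backward runs for a forward that returns its input unchanged versus one that returns inp + 0.0. In recent releases I would expect both counters to be 1, since autograd attaches the custom function to an input returned as-is, but on older versions such as 0.2.0 the results may differ, as described above.

```python
import torch

# Counts how many times each custom backward is invoked
calls = {'identity': 0, 'add_zero': 0}

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp  # returns the input tensor itself

    @staticmethod
    def backward(ctx, grad_out):
        calls['identity'] += 1
        return grad_out

class AddZero(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return inp + 0.0  # returns a new tensor

    @staticmethod
    def backward(ctx, grad_out):
        calls['add_zero'] += 1
        return grad_out

for fn in (Identity, AddZero):
    x = torch.randn(1, requires_grad=True)
    fn.apply(x).backward()

print(calls)
```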