Create custom autograd functions in PyTorch

I am trying to migrate the Chainer code below to PyTorch.
I was looking at this Chainer documentation:
https://docs.chainer.org/en/v7.8.1/reference/generated/chainer.FunctionNode.html

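```python
import chainer
import chainer.functions as F


class CustomErrorFunction(chainer.function_node.FunctionNode):

    def forward(self, inputs):
        # Keep all three inputs (x, y, z) for use in backward().
        self.retain_inputs((0, 1, 2))
        loss = inputs[1]
        # FunctionNode.forward must return a tuple of output arrays.
        return (loss,)

    def backward(self, indexes, gy):
        x, y, z = self.get_retained_inputs()
        gy0 = F.broadcast_to(gy[0], y.shape)
        # Gradient flows only into x; y and z get no gradient.
        return gy0 * z, None, None
```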

I want to know how to create this custom function in PyTorch using torch.autograd.Function.

Are there any functions in PyTorch similar to these Chainer functions, i.e. retain_inputs() and get_retained_inputs()?

Any help would be greatly appreciated.

I’m not sure how retain_inputs works exactly, but I would guess it’s the equivalent of ctx.save_for_backward, as described in this tutorial.
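A minimal sketch of the port, assuming x, y and z are tensors of the same shape (the question does not state the shapes). Here ctx.save_for_backward plays the role of retain_inputs, ctx.saved_tensors plays the role of get_retained_inputs, and Tensor.expand stands in for chainer.functions.broadcast_to. The class name is carried over from the question; returning y.clone() rather than y itself is my own choice, so the output is a fresh tensor.

```python
import torch


class CustomErrorFunction(torch.autograd.Function):
    """Sketch of a PyTorch port of the Chainer FunctionNode above.

    Assumes x, y and z all have the same shape.
    """

    @staticmethod
    def forward(ctx, x, y, z):
        # Counterpart of Chainer's self.retain_inputs((0, 1, 2)).
        ctx.save_for_backward(x, y, z)
        # As in the Chainer code, the output is just the second input.
        # clone() gives a new tensor instead of returning the input object.
        return y.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Counterpart of self.get_retained_inputs().
        x, y, z = ctx.saved_tensors
        # Counterpart of F.broadcast_to(gy[0], y.shape). Since the output
        # has y's shape, this is effectively a no-op here.
        gy0 = grad_output.expand(y.shape)
        # Return one gradient per forward input; None means "no gradient".
        return gy0 * z, None, None


# Small demo: call the Function through .apply, never directly.
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(3, 4)
z = torch.randn(3, 4)

loss = CustomErrorFunction.apply(x, y, z)
loss.sum().backward()

# The upstream gradient of sum() is all ones, so x.grad should equal z.
print(torch.allclose(x.grad, z))
```

Note that the forward output does not depend on x at all, yet backward sends z (scaled by the incoming gradient) into x, just as the Chainer code does. Because of that, torch.autograd.gradcheck would report a mismatch for x, so checking the gradient by hand as in the demo is the more useful test here.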