I am trying to migrate the Chainer code below to PyTorch.
I was looking at the following Chainer documentation:
https://docs.chainer.org/en/v7.8.1/reference/generated/chainer.FunctionNode.html
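```python
import chainer


class CustomErrorFunction(chainer.function_node.FunctionNode):
    def forward(self, inputs):
        # Keep inputs 0, 1 and 2 (x, y, z) so backward() can read them back
        self.retain_inputs((0, 1, 2))
        loss = inputs[1]
        # FunctionNode.forward must return a tuple of arrays
        return loss,

    def backward(self, indexes, gy):
        x, y, z = self.get_retained_inputs()
        gy0 = chainer.functions.broadcast_to(gy[0], y.shape)
        # Gradient goes only to x; y and z receive no gradient
        return gy0 * z, None, None
```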
I want to know how to create this custom function in PyTorch using torch.autograd.Function.
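A minimal sketch of what a direct port might look like, assuming x, y and z are tensors of the same shape (the original code does not say). In PyTorch the custom function uses static `forward`/`backward` methods that receive a `ctx` object. `ctx.save_for_backward(...)` stands in for `retain_inputs(...)`, and `ctx.saved_tensors` stands in for `get_retained_inputs()`. The `y.clone()` in `forward` is my own choice so the output is a new tensor rather than an alias of the input.

```python
import torch


class CustomErrorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z):
        # Counterpart of Chainer's self.retain_inputs((0, 1, 2))
        ctx.save_for_backward(x, y, z)
        # The output is y; clone so it is a new tensor, not an alias of the input
        return y.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Counterpart of Chainer's self.get_retained_inputs()
        x, y, z = ctx.saved_tensors
        # Counterpart of chainer.functions.broadcast_to(gy[0], y.shape)
        gy0 = grad_output.expand_as(y)
        # One entry per forward input (x, y, z); None means "no gradient"
        return gy0 * z, None, None


# autograd.Function subclasses are invoked through .apply, not instantiated
custom_error = CustomErrorFunction.apply
```

With `x.requires_grad=True`, calling `custom_error(x, y, z).sum().backward()` should leave `z` in `x.grad`. This matches the Chainer version, which routes `gy * z` to the first input only.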
Are there any functions in PyTorch similar to these Chainer functions, i.e. retain_inputs() and get_retained_inputs()?
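As far as I understand the PyTorch docs, the closest equivalents are `ctx.save_for_backward(*tensors)` (called in `forward`) and `ctx.saved_tensors` (read in `backward`). Like `retain_inputs`, you can save only a subset. The sketch below keeps just `z`, which is the only input the gradient formula uses. It stores the non-tensor `y.shape` as a plain attribute on `ctx`. The class name `ScaledPassThrough` and the attribute `ctx.y_shape` are hypothetical names used only for illustration. When `saved_tensors` is read, PyTorch also checks that no saved tensor was modified in place after being saved, and raises a `RuntimeError` if one was.

```python
import torch


class ScaledPassThrough(torch.autograd.Function):
    """Hypothetical leaner variant that saves only what backward() needs."""

    @staticmethod
    def forward(ctx, x, y, z):
        # Roughly like retain_inputs((2,)) in Chainer: keep only z
        ctx.save_for_backward(z)
        # Non-tensor data (here a torch.Size) can be stored directly on ctx
        ctx.y_shape = y.shape
        return y.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # ctx.saved_tensors plays the role of get_retained_inputs()
        (z,) = ctx.saved_tensors
        gy0 = grad_output.expand(ctx.y_shape)
        return gy0 * z, None, None
```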
Any help would be greatly appreciated.