Hi, I want to initialize my class (a subclass of torch.autograd.Function) with some variables, but I could not find any example of that. When I try to initialize the object with __init__, the forward pass cannot find those variables. Is it possible to do so?
Below is the only example I could find, where the class MyReLU doesn't have any initial parameters.
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    # how can I initialize the class with some variables here?
    # I've tried this without success
    # def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0):
    #     ctx.alpha = alpha
    #     ctx.beta = beta
    #     ctx.gamma = gamma

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
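For reference, one way this might be approached is sketched below. Because forward and backward are static methods and the Function is invoked through apply rather than instantiated, an __init__ is never in play; the extra values can instead be passed as additional positional arguments to forward and stashed on ctx. The class name ParamReLU and the use of gamma as an output scale are hypothetical, chosen only for illustration, and the sketch assumes alpha, beta and gamma are plain Python numbers rather than Tensors.

```python
import torch


class ParamReLU(torch.autograd.Function):
    """Hypothetical variant of MyReLU that receives extra non-tensor arguments."""

    @staticmethod
    def forward(ctx, input, alpha, beta, gamma):
        # Non-tensor values can be stored as plain attributes on ctx.
        # alpha and beta are only stored here to show the pattern.
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.gamma = gamma
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(input)
        # Illustrative use of a parameter: scale the ReLU output by gamma.
        return gamma * input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Chain rule for the gamma scaling, then zero out the negative side.
        grad_input = grad_output * ctx.gamma
        grad_input[input < 0] = 0
        # One return value per forward argument; None for the non-tensor ones.
        return grad_input, None, None, None


x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
# apply takes positional arguments only.
y = ParamReLU.apply(x, 0.5, 0.5, 2.0)
y.sum().backward()
```

If the values should live on an object, a common alternative is to wrap the call in a torch.nn.Module whose __init__ stores them and whose forward passes them to ParamReLU.apply.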