Initialize parameters in a torch.autograd.Function

Hi, I want to initialize my class (a torch.autograd.Function subclass) with some variables; however, I could not find any example of that. When I try to initialize the object with __init__(), the forward pass cannot find those variables. Is it possible to do so?

Below is the only example I could find, where the class MyReLU doesn't take any initial parameters.

import torch

class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """
    # How can I initialize the class with some variables here?
    # I've tried this without success:
    #   def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0):
    #       ctx.alpha = alpha
    #       ctx.beta = beta
    #       ctx.gamma = gamma

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for the backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input



No, this is not possible.
We don't really use Function as a class, but more as a convenient way to store the pair of functions for the forward and backward.
You can pass these arguments directly to forward, though (and return None for their gradients in backward).
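A minimal sketch of that approach, using a hypothetical ScaledReLU as illustration: the extra value alpha is passed as an argument to forward, stashed on ctx as a plain attribute (save_for_backward is only for tensors), and backward returns one gradient per forward input, with None for the non-tensor argument.

```python
import torch

class ScaledReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, alpha):
        # Tensors go through save_for_backward; plain Python
        # values like alpha can be stored on ctx directly.
        ctx.save_for_backward(input)
        ctx.alpha = alpha
        return input.clamp(min=0) * alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone() * ctx.alpha
        grad_input[input < 0] = 0
        # One return value per forward argument: a gradient for
        # input, and None for the non-tensor argument alpha.
        return grad_input, None

x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = ScaledReLU.apply(x, 2.0)
y.sum().backward()
```

Note that custom Functions are invoked through .apply() rather than by constructing an instance, which is why there is no place for __init__ state to live.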
