PyTorch - Passing Hyperparameters to backprop of autograd.Function

Hello,
I am writing a custom nn.Module class (as a layer) that calls an autograd.Function, for which I want to write a custom backward. The following example is from PyTorch - Extending.

My problem is that I want to control a hyperparameter used in the backward of LinearFunction(Function) from outside, i.e. from the nn.Module class. In the nn.Module class, we can use Function.apply to pass hyperparameters
from the nn.Module to the forward method of LinearFunction(Function). But backward only receives ctx and grad_output.

How do I pass the values from the nn.Module to the LinearFunction?

import torch
import torch.nn as nn
from torch.autograd import Function

class LinearFunction(Function):
    @staticmethod
    # bias is an optional argument for the forward method
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        
        #--------------------------------------------------------------
        # I want to use some hyperparameter in the backward of this Function,
        # but backward only takes ctx and grad_output!
        #--------------------------------------------------------------
        
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

some_hyperparameter = 0.1

class Linear(nn.Module):
    def __init__(self, input_features, output_features, some_hyperparameter, bias=True):
        super(Linear, self).__init__()
        self.input_features  = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)
            
        self.some_hyperparameter = some_hyperparameter

    def forward(self, input):
        #--------------------------------------------------------------
        # to do: pass the hyperparameter to LinearFunction.backward (?)
        #--------------------------------------------------------------
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

The typical way to do this (PyTorch uses it internally, too) is to pass the hyperparameter to forward, store it as a ctx attribute (ctx.my_hyperparameter = ...), and read it back as ctx.my_hyperparameter in backward. Only input/output tensors need to go through save_for_backward; plain Python values can live directly on ctx, so you should be safe here.
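A minimal sketch of that pattern, applied to your example. Purely for illustration I assume the hyperparameter scales grad_input; the extra forward argument and its use in backward are my additions:

import torch
import torch.nn as nn
from torch.autograd import Function

class LinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, some_hyperparameter=0.1):
        ctx.save_for_backward(input, weight, bias)
        # non-tensor hyperparameters go directly on ctx, not through save_for_backward
        ctx.some_hyperparameter = some_hyperparameter
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # read the hyperparameter back; scaling grad_input is just an example use
            grad_input = ctx.some_hyperparameter * grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # one return value per forward argument: None for the non-tensor hyperparameter
        return grad_input, grad_weight, grad_bias, None

In your Linear module, forward then becomes:

    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self.bias, self.some_hyperparameter)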

Best regards

Thomas

Could you also tell me how to pass those parameters to the forward function without having to compute gradients for them?

When I pass 5 arguments to the forward function, including 2 that are non-trainable hyperparameters, PyTorch asks me to return 5 gradients.

Return None as the gradient for those arguments.
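For example, with a hypothetical Function taking five forward arguments, two of them (alpha and beta, names made up here) non-trainable hyperparameters:

from torch.autograd import Function

class ScaledLinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, alpha, beta):
        # alpha and beta are non-trainable hyperparameters
        ctx.save_for_backward(input, weight, bias)
        ctx.alpha, ctx.beta = alpha, beta
        return alpha * input.mm(weight.t()) + beta * bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = ctx.alpha * grad_output.mm(weight)
        grad_weight = ctx.alpha * grad_output.t().mm(input)
        grad_bias = ctx.beta * grad_output.sum(0)
        # five forward arguments -> five returned gradients;
        # None for alpha and beta, which need no gradient
        return grad_input, grad_weight, grad_bias, None, None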
