Hello,
I am writing a custom `nn.Module` class (used as a layer) that calls a custom autograd `Function`, for which I want to write my own `backward`. The example below is adapted from the "Extending PyTorch" section of the PyTorch documentation.
My problem is that I want to control a hyperparameter used in the `backward` of `LinearFunction(Function)` from outside, i.e. from the `nn.Module` class. In the `nn.Module`, I can use `Function.apply` to pass hyperparameters from the `nn.Module` to the `forward` method of `LinearFunction(Function)`. But `backward` only receives the gradients of the outputs (`grad_output`).
How do I pass values from the `nn.Module` to `LinearFunction.backward`?
class LinearFunction(Function):
    """Custom autograd function computing ``input.mm(weight.t()) + bias``.

    An optional, non-tensor ``hyperparameter`` can be passed through
    ``LinearFunction.apply``; it is stored on ``ctx`` in ``forward`` so that
    ``backward`` can read it as ``ctx.hyperparameter``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, hyperparameter=None):
        """Compute the affine map and stash what ``backward`` needs.

        Args:
            ctx: autograd context shared with ``backward``.
            input: tensor multiplied by ``weight.t()`` via ``mm`` (so both
                must be 2-D matrices with compatible inner dimensions).
            weight: weight tensor; transposed before the product.
            bias: optional tensor added to every row of the output via
                ``unsqueeze(0).expand_as(output)``.
            hyperparameter: optional arbitrary (non-tensor) value made
                available to ``backward``. It is not differentiated.

        Returns:
            The output tensor.
        """
        ctx.save_for_backward(input, weight, bias)
        # save_for_backward is meant for tensors; plain Python values are
        # simply attached to ctx as attributes and read back in backward().
        ctx.hyperparameter = hyperparameter
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, hyperparameter)``.

        Each gradient is ``None`` when autograd did not request it (see
        ``ctx.needs_input_grad``). The hyperparameter is never differentiated,
        so its slot is always ``None``; autograd accepts trailing ``None``
        gradients when fewer inputs were passed to ``apply``.
        """
        input, weight, bias = ctx.saved_tensors
        # ctx.hyperparameter holds the value passed to forward(); read it
        # here to customise the gradients computed below.
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias, None


# Example value; not referenced by the classes below.
some_hyperparameter = 0.1


class Linear(nn.Module):
    """Fully connected layer whose forward/backward run through LinearFunction.

    Args:
        input_features: number of columns of the input matrix.
        output_features: number of columns of the output matrix.
        some_hyperparameter: value forwarded to ``LinearFunction`` on every
            call, readable in its backward as ``ctx.hyperparameter``.
        bias: if True, a learnable bias parameter is created.
    """

    def __init__(self, input_features, output_features, some_hyperparameter, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)
        # Plain attribute: reassigning it changes what the next forward()
        # passes on; each backward sees the value captured at its forward.
        self.some_hyperparameter = some_hyperparameter

    def forward(self, input):
        """Apply LinearFunction, passing the hyperparameter through to backward."""
        # apply() is positional, so bias (possibly None) is passed explicitly
        # to keep the hyperparameter in forward()'s fourth slot.
        return LinearFunction.apply(
            input, self.weight, self.bias, self.some_hyperparameter
        )

    def extra_repr(self):
        return 'input_features={}, output_features={}, bias={}, some_hyperparameter={}'.format(
            self.input_features,
            self.output_features,
            self.bias is not None,
            self.some_hyperparameter,
        )