For the purposes of fundamental research I need to modify the function behind nn.Linear (both its forward and backward passes). I require an exact Python version of the LinearFunction class that behaves like the default C++ implementation (nn.functional.linear, i.e. torch._C._nn.linear). The custom Linear example provided in the docs fails on batched (3D) input, because mm does not broadcast, whereas the default C++ implementation handles it:
https://pytorch.org/docs/master/notes/extending.html
class LinearFunction(Function):
    ...
        output = input.mm(weight.t())  # -> "RuntimeError: self must be a matrix" for 3D input
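A minimal sketch of where the error comes from, with arbitrary illustrative shapes: torch.mm only accepts 2D tensors, while torch.matmul treats any leading dimensions as batch dimensions.

```python
import torch

x = torch.randn(4, 5, 3)   # (batch, seq, in_features); shapes are illustrative
w = torch.randn(2, 3)      # (out_features, in_features)

try:
    x.mm(w.t())            # mm requires both operands to be 2D, so this raises
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("mm failed:", err)

y = x.matmul(w.t())        # matmul broadcasts over the leading (batch) dims
print(y.shape)             # torch.Size([4, 5, 2])
```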
You could check the backend implementation for linear here and use similar logic in your custom code.
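For reference, a rough Python sketch of the dispatch in Linear.cpp as I understand it: with 2D input and a bias it calls a fused addmm, otherwise it uses matmul with the transposed weight and then adds the bias. linear_like is a hypothetical name, and the real C++ code has extra fast paths (e.g. for MKLDNN tensors and, in newer versions, contiguous 3D input) that are not reproduced here.

```python
import torch
import torch.nn.functional as F


def linear_like(input, weight, bias=None):
    """Hypothetical helper approximating the forward dispatch in aten's Linear.cpp."""
    if input.dim() == 2 and bias is not None:
        # Fused path: bias + input @ weight.T in a single call
        return torch.addmm(bias, input, weight.t())
    # General path: matmul broadcasts over leading dims and also accepts 1D input
    output = input.matmul(weight.t())
    if bias is not None:
        output = output + bias  # (out_features,) broadcasts over all leading dims
    return output


torch.manual_seed(0)
w = torch.randn(2, 3, dtype=torch.float64)
b = torch.randn(2, dtype=torch.float64)
results = []
for shape in [(3,), (4, 3), (4, 5, 3)]:
    x = torch.randn(*shape, dtype=torch.float64)
    ok = torch.allclose(linear_like(x, w, b), F.linear(x, w, b))
    results.append(ok)
    print(shape, ok)
```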
Thanks, I had a look at Linear.cpp but was not confident converting it. Here is a version of the LinearFunction example from the extending guide that supports batched (3D) input. It replaces mm with matmul, and t with swapaxes wherever the tensor may be batched (weight is always 2D, so weight.t() is kept):
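import torch
from torch.autograd import Function


class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        # matmul (unlike mm) broadcasts over leading batch dims; weight is 2D so .t() is fine
        output = input.matmul(weight.t())
        if bias is not None:
            # bias of shape (out_features,) broadcasts over all leading dims
            output += bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.swapaxes(-1, -2).matmul(input)
            if grad_weight.dim() > 2:
                # Batched input: sum the per-batch gradients down to weight's (out, in) shape
                grad_weight = grad_weight.sum(dim=tuple(range(grad_weight.dim() - 2)))
        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over every dim except the last (2D: dim 0; 3D: dims 0 and 1)
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        return grad_input, grad_weight, grad_bias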
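One way to sanity-check the backward formulas for batched input, independently of the class itself, is to compare them with the gradients autograd computes for F.linear. The shapes and seed below are arbitrary; the point is that the weight and bias gradients must be summed over the batch dimensions to get back to the parameters' own shapes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 5, 3, dtype=torch.float64, requires_grad=True)  # (batch, seq, in)
w = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)     # (out, in)
b = torch.randn(2, dtype=torch.float64, requires_grad=True)        # (out,)

y = F.linear(x, w, b)
g = torch.randn_like(y)        # an arbitrary upstream gradient dL/dy
y.backward(g)                  # reference gradients land in x.grad, w.grad, b.grad

with torch.no_grad():
    grad_input = g.matmul(w)                           # (4, 5, 3)
    grad_weight = g.swapaxes(-1, -2).matmul(x).sum(0)  # (4, 2, 3) summed over batch -> (2, 3)
    grad_bias = g.reshape(-1, g.shape[-1]).sum(0)      # sum over all dims but the last -> (2,)

print(torch.allclose(grad_input, x.grad),
      torch.allclose(grad_weight, w.grad),
      torch.allclose(grad_bias, b.grad))
```

Once LinearFunction is defined, torch.autograd.gradcheck(LinearFunction.apply, (x, w, b)) with double-precision inputs that require grad performs a similar, numerical check on the class itself.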
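Note that nn.functional.linear also supports 1-dimensional input, which needs extra care. swapaxes(-1, -2) fails on a 1D grad_output, and simply skipping the transpose is not a fix: grad_output.matmul(input) would then compute a dot product (which only runs when in_features == out_features), whereas the weight gradient must be the outer product of grad_output and input. Likewise grad_output.sum(0) collapses a 1D bias gradient to a scalar, and in the forward pass the bias has to be added with plain broadcasting (output += bias), because bias.unsqueeze(0).expand_as(output) cannot expand to a 1D output. Flattening the leading dimensions handles 1D, 2D and batched input with a single code path:
if ctx.needs_input_grad[1]:
    # Collapse all leading dims: (out,) -> (1, out), (B, N, out) -> (B*N, out)
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    input_2d = input.reshape(-1, input.shape[-1])
    grad_weight = grad_output_2d.t().matmul(input_2d)
if bias is not None and ctx.needs_input_grad[2]:
    grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)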
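A quick check of the 1D case against autograd, with arbitrary shapes: the weight gradient of F.linear for 1D input equals torch.outer(grad_output, input), which is also what reshaping both tensors to (1, n) before the matmul produces.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)     # 1D input (in,)
w = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)  # (out, in)
b = torch.randn(2, dtype=torch.float64, requires_grad=True)     # (out,)

y = F.linear(x, w, b)          # 1D output of shape (out,)
g = torch.randn_like(y)        # arbitrary upstream gradient
y.backward(g)

with torch.no_grad():
    # Reshaping to (1, out) and (1, in) turns the matmul into an outer product
    g2d = g.reshape(-1, g.shape[-1])
    x2d = x.reshape(-1, x.shape[-1])
    grad_weight = g2d.t().matmul(x2d)   # (2, 1) @ (1, 3) -> (2, 3)
    grad_bias = g2d.sum(0)              # (2,), not a scalar

print(torch.allclose(grad_weight, w.grad),
      torch.allclose(grad_weight, torch.outer(g, x.detach())),
      torch.allclose(grad_bias, b.grad))
```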