I tried the following standard custom function example from the documentation here:
class LinearFunction(Function):
    """Custom autograd function computing ``input.mm(weight.t())`` plus an optional bias.

    ``forward`` multiplies ``input`` by the transpose of ``weight`` and, when a
    bias is given, adds it to every row of the result. ``backward`` returns one
    gradient per ``forward`` argument (``input``, ``weight``, ``bias``), using
    ``None`` for any argument that does not need a gradient.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear map and stash the operands for ``backward``.

        Args:
            ctx: autograd context used to pass tensors to ``backward``.
            input: left operand of ``mm``.
            weight: matrix whose transpose is the right operand of ``mm``.
            bias: optional tensor broadcast across the rows of the output.

        Returns:
            ``input.mm(weight.t())``, with ``bias`` added when it is not None.
        """
        # bias is an optional argument; it is saved even when it is None so
        # that backward can always unpack three values.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``input``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the single forward output.

        Returns:
            Tuple ``(grad_input, grad_weight, grad_bias)``; an entry is None
            when the corresponding forward argument needs no gradient.
        """
        # This function has only a single output, so it gets only one gradient.
        # NOTE(review): on pre-0.4 PyTorch, where Variable and Tensor were
        # distinct types, the saved values may have to be read through
        # ctx.saved_variables instead - confirm against the installed version.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already drops the reduced dimension. An extra squeeze(0)
            # would turn a length-1 gradient into a 0-d tensor whose shape no
            # longer matches a bias of shape (1,).
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
# Expose the custom function as a plain callable.
linear = LinearFunction.apply

# Sizes: rows of the input, shared inner dimension, rows of the weight.
b = 10
p = 5
q = 1

# Only the weight requires a gradient; the input is treated as constant data.
weight = nn.Parameter(torch.rand(q, p), requires_grad=True)
in_tensor = Variable(torch.rand(b, p), requires_grad=False)

# Reduce to a scalar so backward() can be called without an explicit gradient.
out = linear(in_tensor, weight)
out = out.mean()
out.backward()
When I run it, I get the following error:
RuntimeError: mm(): argument 'mat2' (position 1) must be Variable, not torch.FloatTensor
.
Can you spot where my mistake is?