Hi, I’m new to PyTorch. I implemented a custom function to compute the Hadamard product of matrices:

class HadamardProd(autograd.Function):
    #@staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = torch.mul(input, weight)
        if bias is not None:
            output += bias
        return output

    #@staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.mul(grad_output, weight)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.sum(grad_output * weight, 0)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = torch.sum(grad_output, 0)
        if bias is not None:
            return grad_input, grad_weight, grad_bias
        else:
            return grad_input, grad_weight

I used autograd.gradcheck to check my gradients and it returned True. But when I used the corresponding layer of this function in my network, loss.backward() raised a TypeError saying that torch.mul received an invalid combination of arguments (torch.FloatTensor, Variable). I don’t know what’s wrong with the code.

And if I uncomment the line

#@staticmethod

I got an AttributeError that

'torch.FloatTensor' object has no attribute 'save_for_backward'
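That AttributeError is a symptom of mixing calling conventions: with @staticmethod, forward must be invoked through .apply(), which supplies the ctx object as the first argument. If the function is instead called the old way, every argument shifts left by one and the first input tensor lands in the ctx slot. Here is a toy plain-Python sketch of the mechanism (Ctx and GoodFn are made-up names, not real torch internals):

```python
class Ctx:
    """Stand-in for torch's context object."""
    def save_for_backward(self, *tensors):
        self.saved = tensors

class GoodFn:
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @classmethod
    def apply(cls, *args):
        # apply() creates the ctx object and prepends it itself
        ctx = Ctx()
        return cls.forward(ctx, *args)

print(GoodFn.apply(3, 4))  # 12: ctx supplied by apply()

# Bypassing apply() shifts the arguments, so the first number
# ends up in the ctx slot -- the same shape of error as
# "'torch.FloatTensor' object has no attribute 'save_for_backward'":
try:
    GoodFn.forward(3, 4, 5)
except AttributeError as err:
    print(err)  # 'int' object has no attribute 'save_for_backward'
```

So the decorator itself is fine; what matters is pairing it with HadamardProd.apply(...) at the call site.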

First, HadamardProd inherits from nn.Function, but the examples all inherit from either torch.nn.Module or torch.autograd.Function. Changing that might help.

It looks like your module is doing simple calculations using PyTorch operations. If that is the case, I would suggest leaving out the backward function; PyTorch will figure it out on its own. This should work. I also changed the in-place add to an ordinary add, though this might not be strictly necessary.

class HadamardProd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        output = torch.mul(input, weight)
        if bias is not None:
            output = output + bias
        return output

That should help you along even though it doesn’t answer your specific question.
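The torch.nn.Module route mentioned above is usually the simplest option for ops built from differentiable primitives: a plain Module needs no backward at all, because autograd differentiates torch.mul and + on its own. A sketch along those lines (the class name, shape argument, and parameter initialization are my own choices, not from the original post):

```python
import torch
import torch.nn as nn

class HadamardLayer(nn.Module):
    """Elementwise (Hadamard) product with a learnable weight and optional bias."""
    def __init__(self, shape, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(*shape))
        self.bias = nn.Parameter(torch.zeros(*shape)) if bias else None

    def forward(self, input):
        output = input * self.weight      # autograd records this op
        if self.bias is not None:
            output = output + self.bias   # non-inplace add
        return output
```

With this, loss.backward() populates layer.weight.grad and layer.bias.grad without any hand-written gradient code.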

Sorry, I typed something wrong. HadamardProd does inherit from torch.autograd.Function in my original code. I don’t know whether leaving out the backward function will work; I’ll try it. Thanks for your suggestions.

Sorry, I typed something wrong. HadamardProd does inherit from torch.autograd.Function in my original code. Thanks for your suggestion. But leaving out the backward function doesn’t work. It raises a NotImplementedError.
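That NotImplementedError is expected: a torch.autograd.Function must define backward, since its whole point is to bypass autograd. For reference, a sketch of the all-@staticmethod style that avoids both of the earlier errors, assuming weight has the same shape as input (note it also uses grad_output * input for grad_weight, since d(input * weight)/d(weight) = input, and it always returns one gradient slot per forward argument instead of a variable number):

```python
import torch

class HadamardProd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input * weight
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is the current name for the old saved_variables
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * weight
        if ctx.needs_input_grad[1]:
            # elementwise product: d(input * weight)/d(weight) = input
            grad_weight = grad_output * input
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output
        # one gradient (or None) per argument of forward
        return grad_input, grad_weight, grad_bias

# call via .apply, never via HadamardProd()(...):
# out = HadamardProd.apply(x, w)
```

With double-precision inputs this version passes torch.autograd.gradcheck both with and without a bias.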