Hello All,
I’m trying to extend pytorch by adding a new module and function.
class Linear(torch.autograd.Function):
@staticmethod
def forward(self, input, weight_Pos, weight_Neg, bias=None):
self.save_for_backward(input,weight_Pos, weight_Neg, bias)
weight =weight_Pos - weight_Neg
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
else:
output = input.matmul(weight.t())
if bias is not None:
output += torch.jit._unwrap_optional(bias)
ret = output
return ret
@staticmethod
def backward(self, grad_output):
input, weight_Pos, weight_Neg, bias = self.saved_tensors
grad_input = grad_weight_Pos= grad_weight_Neg = grad_bias = None
if self.needs_input_grad[0]:
grad_input = grad_output.mm(weight_Pos)
if self.needs_input_grad[1]:
grad_weight_Pos= grad_output.t().mm(input)
if self.needs_input_grad[2]:
grad_weight_Neg = -1*grad_output.t().mm(input)
if bias is not None and self.needs_input_grad[3]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_weight_Pos, grad_weight_Neg, grad_bias
But, I’m getting an error and I don’t know the problem.
<ipython-input-205-f9f4df5e472f> in backward(self, grad_output)
29 grad_input = grad_output.mm(weight_Pos)
30 if self.needs_input_grad[1]:
---> 31 grad_weight_Pos= grad_output.t().mm(input)
32 if self.needs_input_grad[2]:
33 grad_weight_Neg = -1*grad_output.t().mm(input)
RuntimeError: t() expects a 2D tensor, but self is 1D