Extending module and function pytorch

(MF) #1

Hello All,

I’m trying to extend pytorch by adding a new module and function.

class Linear(torch.autograd.Function):
    @staticmethod
    def forward(self, input, weight_Pos, weight_Neg, bias=None):
        self.save_for_backward(input,weight_Pos, weight_Neg, bias)
        weight =weight_Pos - weight_Neg
        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += torch.jit._unwrap_optional(bias)
            ret = output
        return ret    
    
    @staticmethod
    def backward(self, grad_output):
        input, weight_Pos, weight_Neg, bias = self.saved_tensors
        grad_input = grad_weight_Pos= grad_weight_Neg = grad_bias = None
        if self.needs_input_grad[0]:
            grad_input = grad_output.mm(weight_Pos)
        if self.needs_input_grad[1]:
            grad_weight_Pos= grad_output.t().mm(input)
        if self.needs_input_grad[2]:
            grad_weight_Neg = -1*grad_output.t().mm(input)
        if bias is not None and self.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight_Pos, grad_weight_Neg, grad_bias

But, I’m getting an error and I don’t know the problem.

<ipython-input-205-f9f4df5e472f> in backward(self, grad_output)
     29             grad_input = grad_output.mm(weight_Pos)
     30         if self.needs_input_grad[1]:
---> 31             grad_weight_Pos= grad_output.t().mm(input)
     32         if self.needs_input_grad[2]:
     33             grad_weight_Neg = -1*grad_output.t().mm(input)

RuntimeError: t() expects a 2D tensor, but self is 1D