Autograd_output in a simple custom linear layer

Dear Experts

I try to generate a simple custom linear layer as follows, but the prediction of the network is incorrect :frowning:
I tried hard for more than 2 weeks but I could not solve it. I hope someone help me.

class linearZ(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight):

        ctx.save_for_backward(input, weight)
        l2IN = input
        l2 = l2IN * weight
        return l2

    @staticmethod
    def backward(ctx, grad_output):

        learnign_rate = 0.01
        input, weight = ctx.saved_tensors
        grad_weight = input * grad_output
        weight = weight - learnign_rate * grad_weight
        net.linearZ.weight.data = weight
        return grad_weight, None


class MyLinearZ(nn.Module):
    def __init__(self):
        super(MyLinearZ, self).__init__()
        self.fn = linearZ.apply
        self.weight = nn.Parameter(torch.Tensor([[np.random.randn(1, 1) * 5.66]]))

    def forward(self, x):
        x = self.fn(x, self.weight)
        return x


class Net(nn.Module):
    def __init__(self, conv_weight):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 2)
        self.pool = nn.MaxPool2d((1, 3))
        self.linearZ = MyLinearZ()
        self.conv1.weight = nn.Parameter(conv_weight)



    def forward(self, x):
        x = torch.atan(self.conv1(x))
        x = self.pool(x)
        x = torch.atan(self.linearZ(x))
        return x

Thanks before all

Iā€™m not sure how linearZ is supposed to work, as it seems like you would like to calculate the gradients in backward as well as manipulate some model parameters.
Could you explain the use case a bit as I think both steps (gradient calculation and parameter manipulation) should be done separately.
In case you would just like to reimplement a simple linear function, this code should work:

class linearZ(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        output = input.mm(weight.t())
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        return grad_input, grad_weight


class MyLinearZ(nn.Module):
    def __init__(self):
        super(MyLinearZ, self).__init__()
        self.fn = linearZ.apply
        self.weight = nn.Parameter(torch.randn(1, 1) * 5.66)

    def forward(self, x):
        x = self.fn(x, self.weight)
        return x
3 Likes

Thanks ptrblk, you are the best :rose: