Parameters are not updated while training in custom layer

Dear all,

I want to create a custom 2D convolutional layer whose weights are not standard. I made a simplified dummy version of it, shown below. It has only one parameter (self.weight); in the forward function I build the conv2d kernels (lweight) from this parameter and call F.conv2d with them. It runs (i.e. training proceeds), but self.weight is never updated by the training process. My hypothesis is that this happens because I call F.conv2d with lweight rather than with self.weight.
How can I link self.weight in order to make autograd update it?

Thank you.

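import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.modules.utils import _pair, _single
# _repeat_tuple is assumed to be the private padding helper from the PyTorch
# version _ConvNd was copied from; it is not defined here.

# From _ConvNd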
class TestConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, transposed=False, groups=1, bias=True, padding_mode='zeros'):

        super(TestConv, self).__init__()

        # store values
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = _repeat_tuple(_single(self.padding), 2)

        # single parameter
        self.weight = Parameter( torch.Tensor( 1, 1 ) )

        # import code
        # code.interact(local=locals())

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        # import code
        # code.interact(local=locals())

        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        # Get value from single parameter
        WeightValue = self.weight.data.tolist()[0]
        print( WeightValue )

        # Create kernel tensors with single weight value
        _kernels = []
        for filter in range(self.out_channels):
            num_weight = 0
            _channel_kernels = []
            for inchan in range(self.in_channels):
                _channel_weight = np.full((self.kernel_size, self.kernel_size), WeightValue, np.float32)
                _channel_kernels.append(_channel_weight)
            _kernels.append( _channel_kernels )

        lweight = torch.tensor(_kernels, requires_grad=True).to(self.weight.device)

        return self._conv_forward(input, lweight)

Hi,

In general, you should not use .data: it breaks the graph.
Remove it and use only PyTorch tensor operations so that autograd can track the computation.
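A minimal sketch of what goes wrong, using a stand-in for self.weight (the 1x1 parameter from the question; the 3x3 kernel size is arbitrary): taking the value through .data / .tolist() and rebuilding a tensor with torch.tensor(..., requires_grad=True) creates a new, independent leaf tensor. backward() stops at that leaf, so the parameter itself gets no gradient.

```python
import torch
import torch.nn as nn

weight = nn.Parameter(torch.tensor([[0.5]]))  # stand-in for self.weight

# Pattern from the question: go through .data / Python values, then rebuild a tensor.
value = weight.data.tolist()[0][0]                              # plain Python float
lweight = torch.tensor([[value] * 3] * 3, requires_grad=True)   # new, independent leaf
lweight.sum().backward()
print(weight.grad)   # None: nothing links lweight back to weight
print(lweight.grad)  # the gradient landed on the new leaf instead

# Tensor-only version: the kernel is computed from weight, so the graph is kept.
lweight2 = weight.expand(3, 3)
lweight2.sum().backward()
print(weight.grad)   # tensor([[9.]]): one unit of gradient from each of the 9 entries
```

expand() is just one autograd-tracked way to broadcast the single value; any chain of tensor operations that starts from weight behaves the same.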

Hi!

First, thank you for answering,

I tried your solution, and also a simpler variant (creating the kernels from a constant value instead of the weight tensor), and the result is the same: self.weight does not change during training.

Do you have any other advice? Should I try to create a custom Function as explained here: https://pytorch.org/docs/stable/notes/extending.html?

Best regards.

You won’t need a custom Function.

Can you share the version without the .data please?
Also make sure you don’t use .detach() inside the forward as that would break the graph as well.

Hi,

Here are the 3 options I tested to get WeightValue from the tensor (the last one is a test with a constant value):

        WeightValue = self.weight.tolist()[0]
        WeightValue = self.weight[0,0].clone().detach().cpu()
        WeightValue = 4.0

In none of them is the weight tensor updated.
Should I find a way to compute lweight directly as a tensor from the value of self.weight?

So:

  1. You convert to a Python object, so autograd cannot track it. You need to use only PyTorch tensors.
  2. You call detach(), which explicitly asks for no gradient to flow back to the weight.
  3. The value you use is independent of the weight, so the weight gets no gradient from it.

In all three cases, self.weight gets no gradient, so it is expected that it is not updated.
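The three cases above can be checked directly. This is a small sketch with a hypothetical helper, gradient_reaches, that builds a 3x3 kernel from a value derived from a 1x1 parameter and reports whether backward() produces a gradient for that parameter. A fourth option, which indexes the parameter using tensor operations only, is added for comparison.

```python
import torch
import torch.nn as nn


def gradient_reaches(make_value):
    """Hypothetical helper: build a 3x3 kernel from make_value(w) and report
    whether backward() produces a gradient for the parameter w."""
    w = nn.Parameter(torch.tensor([[2.0]]))
    value = make_value(w)
    if not torch.is_tensor(value):
        value = torch.tensor(value)  # Python float/list -> fresh tensor, no history
    kernel = value * torch.ones(3, 3)
    if not kernel.requires_grad:
        return False  # nothing links kernel to w (backward() would raise here)
    kernel.sum().backward()
    return w.grad is not None


options = {
    "tolist()": lambda w: w.tolist()[0],                          # option 1
    "clone().detach()": lambda w: w[0, 0].clone().detach().cpu(),  # option 2
    "constant": lambda w: 4.0,                                     # option 3
    "tensor ops only": lambda w: w[0, 0],                          # keeps the graph
}
results = {name: gradient_reaches(fn) for name, fn in options.items()}
for name, reached in results.items():
    print(f"{name:>16}: gradient reaches the weight = {reached}")
```

Only the last option reports True, which matches the explanation: the gradient can reach self.weight only through an unbroken chain of tensor operations.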

Hi,

OK, I will do all the computation with tensors. In my real application, it will be trickier than in this simple example.
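A minimal tensor-only sketch of the dummy layer, assuming kernel_size is a single int and dropping padding_mode, dilation and groups for brevity. The class name SharedWeightConv is hypothetical. The kernel is obtained from self.weight with expand(), so there is no .data, .tolist(), NumPy or torch.tensor(...) step to cut the graph, and a single optimizer step is enough to see the shared weight change.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedWeightConv(nn.Module):
    """Hypothetical simplified TestConv: all kernel entries share one scalar
    parameter, and the kernel is built from it with tensor operations only."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size  # assumed to be a single int
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.empty(1, 1))  # the single shared value
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # Same initialization as the original TestConv.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # expand() broadcasts the (1, 1) parameter to (out, in, k, k) and
        # contiguous() materializes it; both are autograd operations, so the
        # gradient of the loss flows back into self.weight.
        lweight = self.weight.expand(self.out_channels, self.in_channels,
                                     self.kernel_size, self.kernel_size).contiguous()
        return F.conv2d(input, lweight, self.bias, self.stride, self.padding)


# Usage: one SGD step on random data should change the shared weight.
torch.manual_seed(0)
layer = SharedWeightConv(3, 4, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
before = layer.weight.detach().clone()

x = torch.randn(2, 3, 8, 8)
out = layer(x)
loss = out.pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(before.item(), "->", layer.weight.item())
```

In a real layer with non-standard weights, the same rule applies: every step from self.weight to the kernel passed to F.conv2d has to be a PyTorch tensor operation.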

Hi,

So, thank you @albanD, it works now. I will be able to go ahead with my tests.

If you cannot write everything with tensors, you will have to write a custom autograd.Function and write the backward pass yourself. It is usually tricky to do.
You can post here if you need any help.
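For the case where part of the kernel construction cannot be expressed with tensor operations, here is a sketch of a custom autograd.Function under simplifying assumptions: the hypothetical FillKernel builds the kernel with NumPy (outside autograd) and supplies the backward by hand. Because every kernel entry equals the weight, the weight's gradient is simply the sum of the kernel's gradient. torch.autograd.gradcheck compares the hand-written backward against finite differences, which is a good habit for any custom Function.

```python
import numpy as np
import torch
import torch.nn.functional as F


class FillKernel(torch.autograd.Function):
    """Hypothetical example: build an (out, in, k, k) kernel filled with the single
    value of `weight` using NumPy, with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, weight, shape):
        ctx.weight_shape = weight.shape
        value = weight.detach().cpu().item()              # leave PyTorch on purpose
        kernel = np.full(shape, value, dtype=np.float64)  # non-autograd computation
        return torch.from_numpy(kernel).to(device=weight.device, dtype=weight.dtype)

    @staticmethod
    def backward(ctx, grad_kernel):
        # Each kernel entry equals the weight, so d(entry)/d(weight) = 1 and the
        # weight's gradient is the sum of the gradients of all kernel entries.
        grad_weight = grad_kernel.sum().reshape(ctx.weight_shape)
        return grad_weight, None  # `shape` is not a tensor: no gradient for it


torch.manual_seed(0)
shape = (2, 3, 3, 3)
w = torch.tensor([[0.3]], dtype=torch.float64, requires_grad=True)

# Compare the hand-written backward with finite differences (needs float64).
ok = torch.autograd.gradcheck(lambda t: FillKernel.apply(t, shape), (w,))
print("gradcheck passed:", ok)

# Use the kernel in a convolution; backward() now reaches w.
x = torch.randn(1, 3, 5, 5, dtype=torch.float64)
w.grad = None  # start from a clean gradient
out = F.conv2d(x, FillKernel.apply(w, shape))
out.sum().backward()
print(w.grad)
```

Here the backward is trivial; when the non-tensor computation is more involved, deriving and checking the backward is where most of the difficulty lies.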