Is it possible to set learning rates for each channel of the weight in a conv layer?

Hello,
I wonder whether it is possible to set a separate learning rate for each output channel of the weight in a given conv layer. I wrote an example below, but it raises an error.

# e.g.
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=1, padding=0, stride=1, bias=False)

    def forward(self, x):
        return self.conv(x)

model = Model()
print(model.conv.weight.shape)  # torch.Size([8, 3, 1, 1])

# One learning rate per output channel
lr = [0.01 * i for i in range(1, model.conv.weight.shape[0] + 1)]  # [0.01, 0.02, ..., 0.08]

# Iterating over the weight yields one slice per output channel
torch.optim.Adam([{'params': p, 'lr': l} for p, l in zip(model.conv.weight, lr)])
# ValueError: can't optimize a non-leaf Tensor

Thanks!
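
The error happens because iterating over (or indexing into) a Parameter yields slices produced by an autograd indexing op, so they are non-leaf tensors, and optimizers can only update leaf tensors. A quick way to see this:

import torch
import torch.nn as nn

w = nn.Conv2d(3, 8, kernel_size=1, bias=False).weight
print(w.is_leaf)     # True  - the Parameter itself is a leaf tensor
print(w[0].is_leaf)  # False - a slice has a grad_fn and is not a leaf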

A possible solution is to split the weight parameter of a layer (e.g. nn.Conv2d) into multiple parameters, one per output channel, and then concatenate the outputs.

For example, for a layer with 2 input channels and 4 output channels, the code could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv(nn.Module):

    def __init__(self, channel_in, kernel_size):
        super().__init__()
        # One single-output-channel weight per output channel (4 here);
        # zero init is just for illustration
        self.weight1 = nn.Parameter(torch.zeros(1, channel_in, kernel_size, kernel_size))
        self.weight2 = nn.Parameter(torch.zeros(1, channel_in, kernel_size, kernel_size))
        self.weight3 = nn.Parameter(torch.zeros(1, channel_in, kernel_size, kernel_size))
        self.weight4 = nn.Parameter(torch.zeros(1, channel_in, kernel_size, kernel_size))

    def forward(self, input):
        # Each conv produces one output channel
        output1 = F.conv2d(input, self.weight1)
        output2 = F.conv2d(input, self.weight2)
        output3 = F.conv2d(input, self.weight3)
        output4 = F.conv2d(input, self.weight4)
        return torch.cat((output1, output2, output3, output4), dim=1)

And then you can apply channel-wise learning rate tuning by assigning a different learning rate to each of weight1, …, weight4, as in the sketch below.
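
A minimal sketch of the optimizer setup, assuming the MyConv module above (the specific lr values are arbitrary):

# One param group per split weight, each with its own learning rate
conv = MyConv(channel_in=2, kernel_size=3)
optimizer = torch.optim.Adam([
    {'params': conv.weight1, 'lr': 0.01},
    {'params': conv.weight2, 'lr': 0.02},
    {'params': conv.weight3, 'lr': 0.03},
    {'params': conv.weight4, 'lr': 0.04},
])

This works because each weightN is now a leaf Parameter rather than a slice of a larger weight.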

Compared with splitting the weight, I would prefer to modify sgd.py with a few lines. For example, you can change the type of the lr argument to lr: Union[float, Tensor], and change the line torch._foreach_add_(device_params, device_grads, alpha=-lr) by adding a check for a tensor lr and indexing into it per channel.
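
Rather than reproducing an actual patch to torch/optim/sgd.py, here is a minimal standalone sketch of the same idea (ChannelwiseSGD is a hypothetical name, and the lr.view(-1, 1, 1, 1) broadcast assumes every parameter in the group is a conv weight of shape [out_channels, in_channels, kH, kW]; it reuses the Model from the question):

import torch

class ChannelwiseSGD(torch.optim.Optimizer):
    """Plain SGD where lr may be a float or a 1-D tensor with one entry per output channel."""

    def __init__(self, params, lr):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                if torch.is_tensor(lr):
                    # Broadcast the per-channel lr over [out_ch, in_ch, kH, kW]
                    p.add_(-lr.view(-1, 1, 1, 1) * p.grad)
                else:
                    p.add_(p.grad, alpha=-lr)

model = Model()
per_channel_lr = torch.tensor([0.01 * i for i in range(1, 9)])  # one lr per output channel
optimizer = ChannelwiseSGD([model.conv.weight], lr=per_channel_lr)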