How to set different learning rates for weight and bias in one layer?

This toy example works.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight.data.fill_(1)
        self.layer.bias.data.fill_(1)

    def forward(self, x):
        return self.layer(x)

if __name__=="__main__":
    net = Net()
    optimizer = optim.Adam([
                {'params': net.layer.weight},             # uses the default lr (0.1)
                {'params': net.layer.bias, 'lr': 0.01}    # this group overrides the default lr
            ], lr=0.1, weight_decay=0.0001)
    out = net(Variable(torch.Tensor([[1]])))
    out.backward()
    optimizer.step()
    print("weight", net.layer.weight.data.numpy(), "grad", net.layer.weight.grad.data.numpy())
    print("bias", net.layer.bias.data.numpy(), "grad", net.layer.bias.grad.data.numpy())

Output is

weight [[ 0.90000004]] grad [[ 1.]]
bias [ 0.99000001] grad [ 1.]

As you can see, the weight has been updated by ~0.1 * weight.grad and the bias by ~0.01 * bias.grad, i.e. each parameter group is stepped with its own learning rate.
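
For reference, here is a rough sketch of where those step sizes come from (standard Adam defaults assumed, weight decay ignored): on the first step the bias-corrected moment estimates cancel the gradient's magnitude, so the update is approximately the learning rate of each group.

grad = 1.0
beta1, beta2, eps = 0.9, 0.999, 1e-8   # Adam's default hyperparameters

m = (1 - beta1) * grad        # first moment after one step (m_0 = 0)
v = (1 - beta2) * grad ** 2   # second moment after one step (v_0 = 0)
m_hat = m / (1 - beta1)       # bias correction -> equals grad
v_hat = v / (1 - beta2)       # bias correction -> equals grad ** 2

for lr in (0.1, 0.01):
    print(lr, lr * m_hat / (v_hat ** 0.5 + eps))  # ~0.1 for the weight group, ~0.01 for the bias group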

The error you get suggests that you have asked the optimiser to optimise a Variable that isn’t a parameter of your model. But your partial code sample seems fine.
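
In case it helps, a minimal sketch of the kind of construction that typically triggers such a complaint: passing the optimiser a tensor that is the result of a computation (a non-leaf) rather than one of the module's registered parameters. The exact wording of the ValueError depends on your PyTorch version.

import torch.nn as nn
import torch.optim as optim

net = nn.Linear(1, 1)

# Fine: weight and bias are nn.Parameter leaves of the module.
optim.Adam([{'params': net.weight}, {'params': net.bias, 'lr': 0.01}], lr=0.1)

# Not fine: net.weight * 2 is a non-leaf tensor produced by a computation,
# so the optimiser has nothing it can update in place.
try:
    optim.Adam([{'params': net.weight * 2}], lr=0.1)
except ValueError as e:
    print(e)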
