Constructing parameter groups in PyTorch

The torch.optim documentation states that model parameters can be grouped and optimized with different hyperparameters per group. It says:

For example, this is very useful when one wants to specify per-layer
learning rates:

optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

This means that model.base’s parameters will use the default
learning rate of 1e-2, model.classifier’s parameters will use a
learning rate of 1e-3, and a momentum of 0.9 will be used for all
parameters.

I was wondering how to define groups such as model.base and model.classifier that expose a parameters() method. What came to mind was something like this:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.base()
        self.classifier()

        self.relu = nn.ReLU()

    def base(self):
        self.fc1 = nn.Linear(1, 512)
        self.fc2 = nn.Linear(512, 264)

    def classifier(self):
        self.fc3 = nn.Linear(264, 128)
        self.fc4 = nn.Linear(128, 964)

    def forward(self, y0):

        y1 = self.relu(self.fc1(y0))
        y2 = self.relu(self.fc2(y1))
        y3 = self.relu(self.fc3(y2))

        return self.fc4(y3)

How should I modify the snippet above so that model.base.parameters() works? Is the only way to define an nn.ParameterList and explicitly add the weights and biases of the desired layers to it? What is the best practice? Are there any ways other than using nn.Sequential?

Structuring the model with separate nn.Module classes for base and classifier would look more natural to me in your example, but you are not limited to that: the optimizer accepts any list or iterable of parameters as a group's 'params', not just an nn.ParameterList or the parameters of an nn.Sequential.
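To illustrate the structured variant, here is a minimal sketch that reuses the layer sizes from the question. The class names Base and Classifier are hypothetical, chosen only for this example; any nn.Module assigned as an attribute is registered as a submodule, which is what makes model.base.parameters() available.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class Base(nn.Module):
    """Hypothetical submodule holding the first two layers."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 512)
        self.fc2 = nn.Linear(512, 264)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))


class Classifier(nn.Module):
    """Hypothetical submodule holding the last two layers."""

    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(264, 128)
        self.fc4 = nn.Linear(128, 964)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc4(self.relu(self.fc3(x)))


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning modules as attributes registers them as submodules,
        # so model.base.parameters() and model.classifier.parameters() work.
        self.base = Base()
        self.classifier = Classifier()

    def forward(self, y0):
        return self.classifier(self.base(y0))


model = MyModel()

# Same call as in the documentation excerpt quoted in the question.
optimizer = optim.SGD(
    [
        {"params": model.base.parameters()},
        {"params": model.classifier.parameters(), "lr": 1e-3},
    ],
    lr=1e-2,
    momentum=0.9,
)
```

Nothing about the optimizer call requires this layout, though; plain lists of parameters work as well.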

For example:

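# Collect the parameters of the layers that should form one group.
some_params = []
some_params.extend(model.fc1.parameters())
some_params.extend(model.fc2.parameters())

# Everything else goes into the second group. Compare by identity,
# since == on tensors is elementwise.
some_ids = {id(p) for p in some_params}
other_params = [p for p in model.parameters() if id(p) not in some_ids]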
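As a rough end-to-end sketch, the two lists can then be handed to the optimizer as parameter groups. The flat MyModel below mirrors the layers from the question, and the learning rates are taken from the documentation excerpt; both are only illustrative choices.

```python
import torch.nn as nn
import torch.optim as optim


class MyModel(nn.Module):
    """Flat layout: all layers are direct attributes of the model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 512)
        self.fc2 = nn.Linear(512, 264)
        self.fc3 = nn.Linear(264, 128)
        self.fc4 = nn.Linear(128, 964)
        self.relu = nn.ReLU()

    def forward(self, y0):
        y1 = self.relu(self.fc1(y0))
        y2 = self.relu(self.fc2(y1))
        y3 = self.relu(self.fc3(y2))
        return self.fc4(y3)


model = MyModel()

# First group: parameters of fc1 and fc2.
some_params = list(model.fc1.parameters()) + list(model.fc2.parameters())

# Second group: every remaining parameter, matched by object identity.
some_ids = {id(p) for p in some_params}
other_params = [p for p in model.parameters() if id(p) not in some_ids]

optimizer = optim.SGD(
    [
        {"params": some_params},               # uses the default lr below
        {"params": other_params, "lr": 1e-3},  # overrides lr for this group
    ],
    lr=1e-2,
    momentum=0.9,
)
```

Each parameter may appear in only one group, which is why the second list is built by excluding the first.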

You could also filter by name, for example with [p for n, p in model.named_parameters() if n.endswith('bias')].
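A small sketch of the name-based variant follows. The nn.Sequential model here is only a stand-in to keep the example short (its parameters are named "0.weight", "0.bias", and so on), and giving biases their own learning rate is an arbitrary choice for illustration.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in model; any nn.Module with named parameters works the same way.
model = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 2))

# Split parameters by name: biases in one list, everything else in the other.
bias_params = [p for n, p in model.named_parameters() if n.endswith("bias")]
weight_params = [p for n, p in model.named_parameters() if not n.endswith("bias")]

optimizer = optim.SGD(
    [
        {"params": weight_params},            # default lr
        {"params": bias_params, "lr": 1e-3},  # per-group override
    ],
    lr=1e-2,
    momentum=0.9,
)
```

With nested submodules the names carry a prefix such as "base.fc1.weight", so n.startswith("base.") would select a whole submodule in the same way.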

Best regards

Thomas
