Optimizer.step() disregards learning rate with multiple nn.Parameter()

I have a custom module with two nn.Parameter tensors:

class CustomLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.rand(1))
        self.param2 = nn.Parameter(torch.rand(1))

later when I optimize with optimizer.step():

layer = CustomLayer()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.5)
optimizer.step()

it doesn’t do param = param - 0.5*param.grad; instead it computes param = param - param.grad.
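
For reference, here is a minimal self-contained sketch (toy values, not my actual module) of what I expect a single SGD step with lr=0.5 to compute:

import torch
import torch.nn as nn

p = nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([p], lr=0.5)

loss = (p ** 2).sum()               # d(loss)/dp = 2 * p
loss.backward()

before = p.detach().clone()
grad = p.grad.detach().clone()
opt.step()

# p should now equal before - 0.5 * grad
print(torch.allclose(p.detach(), before - 0.5 * grad))  # expected: True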

Hi,

Can you give a code sample that reproduces this issue please?

Hi @albanD, yes, here is a code sample. I’m trying to implement the separate loss from the separate loss paper:

import torch
import torch.nn as nn
torch.manual_seed(3275)


class SeparateLoss(nn.Module):
    def __init__(self, num_classes, feat_dim):
        super(SeparateLoss, self).__init__()
        self.centers = torch.randn(num_classes, feat_dim)
        self.intra_centers = nn.Parameter(self.centers.view_as(self.centers))
        self.inter_centers = nn.Parameter(self.centers.view_as(self.centers))
        self.cosine_similarity = torch.nn.CosineEmbeddingLoss()
        self.reset_params()

    def reset_params(self):
        nn.init.kaiming_normal_(self.centers.data)

    def forward(self, f, label):
        batch_size = label.size(0)
        y = label.new_ones(1, batch_size).float()
        s = self.centers.size(0)

        # calculating intra loss
        # ----------------------
        # fix the scales for c_intra gradients as in eq (6)
        counts = self.intra_centers.new_ones(s)
        ones = self.intra_centers.new_ones(batch_size)
        counts.scatter_add_(0, label.long(), ones)
        intra_hook = self.intra_centers.register_hook(lambda grad: grad/counts.unsqueeze(1))

        # eq (2)
        batch_centers = self.intra_centers.index_select(0, label.long())
        Lintra = (-1/2) * (-1 * self.cosine_similarity(f, batch_centers, y) + 2) + 1

        # calculating inter loss
        # ----------------------
        # fix the scales for c_inter gradients as in eq (7)
        inter_hook = self.inter_centers.register_hook(lambda grad: grad/2.0)
        dist = torch.nn.functional.cosine_similarity(
                self.inter_centers.unsqueeze(1).expand(s, s, self.inter_centers.size(1)),
                self.inter_centers.unsqueeze(0).expand(s, s, self.inter_centers.size(1)),
                dim=2
            )

        Linter = (((dist + 1).sum()) - (2.0 * s))/(2*s*(s - 1))

        # calculating separate loss
        Lsep = Lintra + Linter

        return Lsep, intra_hook, inter_hook


def unit_test():

    device = torch.device('cpu')
    m = 8  # number of samples
    k = 3  # number of classes
    feat_dim = 4

    x = torch.randn(m, feat_dim, requires_grad=True).to(device)
    target = torch.tensor([0, 0, 1, 0, 2, 2, 1, 0]).to(device)

    criterion = SeparateLoss(num_classes=k, feat_dim=feat_dim).to(device)
    center_optimizer = torch.optim.SGD(criterion.parameters(), lr=0.5)
    feature_optimizer = torch.optim.SGD([x], lr=0.1)

    # epoch 1
    # -------
    lsep, intra_hook, inter_hook = criterion(x, target)
    print(lsep.item())

    # backward
    center_optimizer.zero_grad()
    lsep.backward()

    # add intra and inter gradients
    print(f'x.grad = {x.grad}')
    centers_grad = 0
    for param in criterion.parameters():
        centers_grad += param.grad
    # update centers grad
    for param in criterion.parameters():
        param.grad = centers_grad
        print(f'c.grad = {param.grad}')

    # update parameters
    feature_optimizer.step()
    center_optimizer.step()

    # release hooks
    intra_hook.remove()
    inter_hook.remove()


if __name__ == "__main__":
    unit_test()
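
(In case it helps to read the code above: the two register_hook calls only rescale the center gradients before they land in .grad. A minimal standalone sketch of that mechanism, separate from the code above:)

import torch

w = torch.ones(3, requires_grad=True)
handle = w.register_hook(lambda grad: grad / 2.0)   # same idea as inter_hook

loss = (3.0 * w).sum()    # d(loss)/dw = 3.0 for every entry
loss.backward()

print(w.grad)             # tensor([1.5000, 1.5000, 1.5000]) instead of 3.0
handle.remove()           # release the hook, like at the end of unit_test()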

so center_optimizer.step() should do param = param - 0.5*param.grad, but when I debug, the calculated update is param = param - param.grad.

You should replace the following:
(you should not give the same Tensor to different parameters: they share the same memory and so will always have the same value, and if you do an update on one and then on the other, you apply the update twice)

        self.intra_centers = nn.Parameter(self.centers.view_as(self.centers))

by

        self.intra_centers = nn.Parameter(self.centers.clone())

And (you should not use .data)

        nn.init.kaiming_normal_(self.centers.data)

by

        nn.init.kaiming_normal_(self.centers)
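
Here is a small self-contained sketch (toy tensors, made-up names) of why giving the same Tensor to two Parameters is a problem, and why .clone() fixes it:

import torch
import torch.nn as nn

# two Parameters built from the same tensor share its storage
base = torch.ones(2)
shared_a = nn.Parameter(base)
shared_b = nn.Parameter(base)
opt = torch.optim.SGD([shared_a, shared_b], lr=0.5)

(shared_a.sum() + shared_b.sum()).backward()  # grad of 1.0 for each parameter
opt.step()
print(shared_a)   # 1 - 0.5 - 0.5 = 0.0 -> the update was applied twice

# with .clone() each Parameter owns its memory and is updated once
fresh = torch.ones(2)
own_a = nn.Parameter(fresh.clone())
own_b = nn.Parameter(fresh.clone())
opt2 = torch.optim.SGD([own_a, own_b], lr=0.5)

(own_a.sum() + own_b.sum()).backward()
opt2.step()
print(own_a)      # 1 - 0.5 = 0.5 -> updated once, as expected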

Thank you @albanD. This worked for me. Everything is fine in the first epoch, but I have a problem in the second epoch.
Following up on the question I asked yesterday, which you answered:

my model has two sets of parameters, a and b, which require gradients. I have two different loss functions:

loss1 = function(a, b)
loss2 = function(b)

and the total loss is

total_loss = loss1 + loss2

Assuming:

  • loss1 calculates b1.grad
  • loss2 calculates b2.grad,
    then the total b.grad calculated by total_loss.backward() is b.grad = b1.grad + b2.grad.

My goal is to modify b1.grad coming from loss1 and b2.grad coming from loss2 and then add them together as the backward gradients for b. Currently, when I do total_loss.backward(), it already gives me the accumulated gradient b.grad = b1.grad + b2.grad. How can I access and modify each individual b1.grad and b2.grad so that total_loss.backward() returns the modified and added b1.grad and b2.grad?
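
(To make that concrete, here is a minimal toy version of the accumulation I mean, with made-up scalars a and b:)

import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

loss1 = a * b          # d(loss1)/db = a = 2.0   -> "b1.grad"
loss2 = b ** 2         # d(loss2)/db = 2*b = 6.0 -> "b2.grad"
total_loss = loss1 + loss2

total_loss.backward()
print(b.grad)          # tensor(8.) == b1.grad + b2.grad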
You told me to make intermediate gradients, which I did in this code sample. However, I’m not sure whether

centers_grad = 0
for param in criterion.parameters():
    centers_grad += param.grad
# update centers grad
for param in criterion.parameters():
    param.grad = centers_grad
    print(f'c.grad = {param.grad}')

is doing the job of b.grad = b1.grad + b2.grad. It is fine for the first epoch, but doing the same thing in the second epoch already adds the gradients for me. It seems like the centers_grad variable stays attached to something. Do you have any idea how to fix this?

Given that you modify .grad with Tensors that have requires_grad=True, I guess the next backward keeps the graph from the previous iteration? Can you try setting all the .grad fields to None at the end of your iteration (instead of calling .zero_grad())?
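
Something along these lines (a generic module here, not your SeparateLoss, just to show the idea):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.5)

for it in range(2):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()

    # instead of opt.zero_grad(): drop the .grad Tensors entirely so nothing
    # from this iteration's graph is carried over to the next one
    for p in model.parameters():
        p.grad = None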