Loss is changing but the weights are not

# custom initialization of the weights
self.hidden_layer_1.weight = torch.nn.Parameter(torch.from_numpy(weights))
a = self.hidden_layer_1.weight

num_epochs = 10
for epoch in range(num_epochs):
    print('Epoch {}'.format(epoch))
    torch.autograd.set_detect_anomaly(True)
    epoch_loss_train = []
    epoch_loss_validation = []
    for i, (x, y) in enumerate(zip(feature_trainloader, label_trainloader), 0):
        optimizer.zero_grad()
        output = self.forward(x)
        target = y
        loss = self.CrossEntropyLoss(output, target)
        epoch_loss_train.append(loss.item())
        y_true_train = torch.cat((y_true_train, target))
        _, pred_class = torch.max(output, dim=1)
        y_pred_train = torch.cat((y_pred_train, pred_class))
        loss.backward()
        optimizer.step()

b = self.hidden_layer_1.weight
print(torch.equal(a, b))  # True

I don’t understand how the weights can be the same after 10 epochs even though the loss is changing.

Could you share a colab notebook which can reproduce the issue?
A minimal implementation would work.

As a sanity check, can you actually save (e.g., with something like .detach().clone()) or print the values before and after training and compare them? I suspect that a and b are the same reference, so they are both updated as the model is being trained.
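Something along these lines (a minimal, made-up example; the layer and the training loop are just stand-ins for your model):

import torch

layer = torch.nn.Linear(4, 2)  # stand-in for self.hidden_layer_1
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

ref_before = layer.weight                     # same object, not a copy
copy_before = layer.weight.detach().clone()   # independent copy of the values

for _ in range(3):                            # a few dummy training steps
    loss = layer(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

print(torch.equal(ref_before, layer.weight))   # True  - both names point at the same tensor
print(torch.equal(copy_before, layer.weight))  # False - the values really did change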

Hi Jpj!

At this point, you (presumably) have already created your optimizer, so
your optimizer contains a collection of Parameters that it will be updating.

Because you have overwritten weight with a new Parameter (rather
than having overwritten weight.data with a new tensor) – and done so
after creating your optimizer – your optimizer does not contain the new
Parameter that is used (via the forward pass) to calculate the loss.
(The Parameter that the optimizer modifies, or would modify if it had
a non-trivial gradient, is the old Parameter that you have, in a sense, “hidden.”)

So self.hidden_layer_1.weight doesn’t change, because it is not in
the optimizer’s list of Parameters to update.
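A small, hypothetical reproduction of that situation (the layer here is a stand-in, not your original model):

import torch

layer = torch.nn.Linear(3, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)  # optimizer captures the ORIGINAL weight

# overwrite weight with a brand-new Parameter AFTER the optimizer was created
layer.weight = torch.nn.Parameter(torch.zeros(2, 3))

tracked = [p for group in optimizer.param_groups for p in group['params']]
print(any(p is layer.weight for p in tracked))  # False: the new weight is not tracked
print(any(p is layer.bias for p in tracked))    # True:  the bias is still tracked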

Why the loss is changing even though self.hidden_layer_1.weight
is not: I will assume that self.hidden_layer_1 is a Linear. A Linear
has both a weight and a bias. You haven’t overwritten bias, so it is
still updated by your optimizer, and still affects the loss you calculate.
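A minimal sketch of that effect (assuming the layer is a plain Linear with a bias; the names are placeholders):

import torch

layer = torch.nn.Linear(3, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# replace weight after the optimizer was created, as in the original code
layer.weight = torch.nn.Parameter(torch.zeros(2, 3))

weight_before = layer.weight.detach().clone()
bias_before = layer.bias.detach().clone()

for _ in range(5):
    loss = layer(torch.randn(8, 3)).pow(2).mean()
    print(loss.item())                  # the loss still changes from step to step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(torch.equal(weight_before, layer.weight))  # True:  the new weight is never updated
print(torch.equal(bias_before, layer.bias))      # False: the bias is still updated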

Best.

K. Frank


Note that even after fixing the optimizer issue, the weights should not be compared via references.
For example, the code snippet below reports the “saved” weights as unchanged, because saved and a.weight refer to the same underlying tensor; copying the weights before training (detached) produces the expected result:

import torch

a = torch.nn.Linear(128, 1)
b = torch.nn.Sigmoid()
optimizer = torch.optim.SGD(a.parameters(), lr=1e-3)
criterion = torch.nn.BCELoss()
saved = a.weight                       # same object as a.weight, not a copy
detached = a.weight.detach().clone()   # independent copy of the initial values
for i in range(512):
    inp = torch.randn((1, 128))
    label = float(torch.sum(inp[:, 64:]) > 0)
    label = torch.tensor([[label]])
    output = b(a(inp))
    loss = criterion(output, label)
    print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(torch.allclose(saved, a.weight))     # True:  `saved` and `a.weight` are the same tensor
print(saved == a.weight)                   # element-wise, all True
print(torch.allclose(detached, a.weight))  # False: the values were updated in place
print(detached == a.weight)                # element-wise, not all True

Thanks for the reply.

I do create my optimizer before this statement, you are correct.
So if I understand correctly, self.hidden_layer_1.weight.data = my_custom_weights should fix it?

Hi Jpj!

I think that will fix it.

My preferred idiom (not that I know what I’m doing) would be:

with torch.no_grad():
    self.hidden_layer_1.weight.data.copy_(my_custom_weights) 
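
Because copy_() writes into the existing tensor rather than replacing the Parameter object, the optimizer keeps tracking the same weight. A minimal sketch of that check (names are placeholders, assuming custom weights of the right shape):

import torch

layer = torch.nn.Linear(3, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

my_custom_weights = torch.full((2, 3), 0.5)  # must match layer.weight's shape

with torch.no_grad():
    layer.weight.copy_(my_custom_weights)    # in-place copy; same Parameter object

tracked = [p for group in optimizer.param_groups for p in group['params']]
print(any(p is layer.weight for p in tracked))  # True: the optimizer still updates this weight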

Best.

K. Frank