Linear layer weights are updated by backprop, but PyTorch still computes with the old weights

Hi! I'm having a peculiar issue. I'm running inner optimization steps with this code:

```python
inner_model = NN(param.detach(), self.kwargs).to(self.device).train()
inner_opt = torch.optim.SGD(inner_model.model.parameters(), lr=100, momentum=.7, weight_decay=1e-6)
for opt_step in range(20):
    out_inner = inner_model(images)
    loss = torch.sum(mask * F.binary_cross_entropy_with_logits(
        out_inner.squeeze(), inner_cur_lbl[idx].float(), reduction='none'))
```
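For comparison, here is a self-contained sketch of the same kind of inner loop that runs on CPU. It is only an approximation under stated assumptions: NN is replaced by a hypothetical single bias-free nn.Linear, images, mask and labels are random stand-ins, the learning rate is lowered to 0.1 to keep the toy problem stable, and the zero_grad/backward/step calls that the snippet above does not show are written out explicitly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the poster's NN: one bias-free linear layer.
inner_model = nn.Linear(8, 1, bias=False).train()
inner_opt = torch.optim.SGD(inner_model.parameters(), lr=0.1,
                            momentum=0.7, weight_decay=1e-6)

# Hypothetical data: 16 samples, binary labels, and a per-sample mask.
images = torch.randn(16, 8)
labels = torch.randint(0, 2, (16,)).float()
mask = torch.ones(16)

with torch.no_grad():
    out_before = inner_model(images)
# detach() shares storage with the parameter, so clone() to keep a snapshot.
w_before = inner_model.weight.detach().clone()

for opt_step in range(20):
    out_inner = inner_model(images)
    loss = torch.sum(mask * F.binary_cross_entropy_with_logits(
        out_inner.squeeze(1), labels, reduction='none'))
    inner_opt.zero_grad()
    loss.backward()
    inner_opt.step()

with torch.no_grad():
    out_after = inner_model(images)

# In a healthy setup, both the weights and the forward outputs change.
print("weight changed:", not torch.equal(w_before, inner_model.weight.detach()))
print("output changed:", not torch.equal(out_before, out_after))
```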

There are a few important things to point out:
1. The gradients are fine.
2. The weights are updating; I checked them using:

3. The output, a.k.a. out_inner, is the same(!) as before the updates!!
   So I stepped into the nn.Linear forward method:

```python
def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)
```
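As a sanity check on what that forward does: for a bias-free layer, F.linear(input, weight) computes input @ weight.T. The small CPU sketch below (arbitrary shapes, chosen only for illustration) compares the module call, F.linear, and torch.mm. It uses torch.allclose rather than exact equality, since different kernels are not guaranteed to round identically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(4, 3, bias=False)  # no bias, as in the post
x = torch.randn(5, 4)

out_module = layer(x)                                   # nn.Linear.forward
out_functional = F.linear(x, layer.weight, layer.bias)  # what forward calls (bias is None)
out_matmul = torch.mm(x, layer.weight.T)                # manual equivalent

same_fn = torch.allclose(out_module, out_functional)
same_mm = torch.allclose(out_functional, out_matmul, atol=1e-6)
print(same_fn, same_mm)
```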

Inside it, I checked self.weight, and the values are exactly as expected (i.e., they changed exactly as I saw in item 2). Then I tried to run:

```python
return F.linear(input, self.weight, self.bias)
```

But I got the same unchanged results, so I tried another experiment and checked:

```python
torch.mm(input, self.weight.T)  # I have no bias
```

And got different results!! Results showing that self.weight has changed! I.e.:

```python
torch.mm(input, self.weight.T) != F.linear(input, self.weight, self.bias)
```
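One caveat about that comparison: != is element-wise, and two mathematically equivalent kernels can differ in the last bits of floating-point rounding. A way to tell rounding noise apart from genuinely stale weights is to look at the size of the difference, as in this CPU sketch with made-up weights (w_old and w_new are hypothetical): rounding noise is tiny, while results computed from the pre-update weights differ by roughly the size of the update.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 64)
w_old = torch.randn(10, 64)
w_new = w_old - 0.1 * torch.randn(10, 64)  # hypothetical "updated" weight

out_linear = F.linear(x, w_new)
out_mm = torch.mm(x, w_new.T)
out_stale = torch.mm(x, w_old.T)  # what a stale-weight bug would produce

# Rounding noise should be tiny; stale weights give a difference on the
# order of the weight update itself.
diff_rounding = (out_linear - out_mm).abs().max().item()
diff_stale = (out_linear - out_stale).abs().max().item()
print(f"F.linear vs mm (new weights): {diff_rounding:.2e}")
print(f"F.linear vs mm (old weights): {diff_stale:.2e}")
```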

I think something I did improperly made F.linear use an old copy of the weights, even though, once again, self.weight has changed.
The final, and maybe the most important, experiment was running:

```python
F.linear(input, self.weight.clone(), self.bias)
```

This line computed the same result as the matmul, but not the same result as:

```python
F.linear(input, self.weight, self.bias)
```
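For context on why the clone() result should not differ: optimizers such as torch.optim.SGD update parameters in place, so the parameter keeps the same storage while its values change, and clone() merely copies those current values into new storage. The CPU sketch below (toy layer and data, chosen for illustration) checks both facts. On a correctly working device, F.linear with the parameter and with its clone should agree, so a mismatch like the one reported above would suggest something outside this basic mechanism.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(6, 2, bias=False)
opt = torch.optim.SGD(layer.parameters(), lr=0.5)
x = torch.randn(4, 6)

ptr_before = layer.weight.data_ptr()
w_before = layer.weight.detach().clone()

layer(x).pow(2).sum().backward()
opt.step()  # SGD writes the new values into the existing parameter tensor

same_storage = layer.weight.data_ptr() == ptr_before
values_changed = not torch.equal(layer.weight.detach(), w_before)

w_clone = layer.weight.clone()  # new storage, same current values
new_storage = w_clone.data_ptr() != layer.weight.data_ptr()

# Both calls see the same current values, so the results should agree.
same_result = torch.allclose(F.linear(x, layer.weight), F.linear(x, w_clone))

print(same_storage, values_changed, new_storage, same_result)
```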

I know it's not an easy post to follow; I tried my best to illustrate. Does anyone have any idea what is going on?

Edit: a few more hints:
1. Switching from DDP to a single GPU didn't help.
2. This code runs inside the training_step method of a PyTorch Lightning module.
3. When switching to CPU, it works!!!
This is even more puzzling! What am I missing?
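Since the CPU run behaves correctly, one way to narrow this down is to run the same computation on both devices and compare. The helper below, compare_devices, is a hypothetical name; it evaluates F.linear and torch.mm on CPU and, when a GPU is available, on CUDA, and reports whether each agrees with the CPU matmul reference. The tolerance is an assumption and may need loosening for large inputs.

```python
import torch
import torch.nn.functional as F


def compare_devices(x, w, atol=1e-3):
    """Compare F.linear with torch.mm on CPU and, if available, on CUDA.

    Returns {device: {"linear_vs_mm": bool, "linear_vs_cpu_ref": bool}}.
    """
    runs = {"cpu": (F.linear(x, w), torch.mm(x, w.T))}
    if torch.cuda.is_available():
        xg, wg = x.cuda(), w.cuda()
        runs["cuda"] = (F.linear(xg, wg).cpu(), torch.mm(xg, wg.T).cpu())

    reference = runs["cpu"][1]  # CPU matmul result used as the reference
    report = {}
    for device, (out_linear, out_mm) in runs.items():
        report[device] = {
            "linear_vs_mm": torch.allclose(out_linear, out_mm, atol=atol),
            "linear_vs_cpu_ref": torch.allclose(out_linear, reference, atol=atol),
        }
    return report


torch.manual_seed(0)
x = torch.randn(32, 64)
w = torch.randn(10, 64)
print(compare_devices(x, w))
```

In the real model, x would be the input reaching the linear layer and w its (detached) self.weight after the inner updates.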

I doubt it, and would guess that other parts of your model might “saturate” the outputs.

Could you post a minimal and executable code snippet reproducing the issue?
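To illustrate what “saturation” means in this context: if a later nonlinearity is saturated, for example a ReLU whose inputs are all negative or a sigmoid/tanh pushed far into its flat region, a weight update can change a linear layer's output while the value seen further downstream stays exactly the same. The hand-built CPU example below uses arbitrary values chosen so that every pre-activation stays negative.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    # Arbitrary weights chosen so both pre-activations are negative for x.
    layer.weight.copy_(torch.tensor([[-1.0, -1.0], [-2.0, -0.5]]))

x = torch.tensor([[1.0, 2.0]])

with torch.no_grad():
    pre_before = layer(x)              # [[-3., -3.]]
    out_before = torch.relu(pre_before)

    layer.weight.add_(0.1)             # simulate a small weight update

    pre_after = layer(x)               # about [[-2.7, -2.7]]: the linear output changed
    out_after = torch.relu(pre_after)  # ...but the ReLU output is still all zeros

print("linear output changed:", not torch.equal(pre_before, pre_after))
print("ReLU output changed:  ", not torch.equal(out_before, out_after))
```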

I'll try to isolate the code into a runnable example.

About the “saturation”: it was the first thing I checked, but my experiments show nothing like that. Did you see them? F.linear is not using the updated self.weight, while using self.weight with a matmul returns the correct result. Likewise, inside the Linear forward in the torch code, when using:

```python
F.linear(input, self.weight.clone(), self.bias)
```

it returns the correct result. How could this possibly be an output saturation issue?