Model weights not getting updated when using autocast

Hello All,
I am trying to develop a model to generate segmentation masks on the Kvasir-SEG dataset, using a mixture of focal and Dice loss, which is as follows -

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceFocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, alpha=0.25, gamma=3, smooth=1):
        super(DiceFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.smooth = smooth

    def forward(self, inputs, targets):
        # uncomment if your model does not already contain a sigmoid or equivalent activation layer
        # inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # Dice loss, computed here directly on the (unnormalized) inputs
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        # focal term built on top of BCE-with-logits
        BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = self.alpha * (1 - BCE_EXP) ** self.gamma * BCE

        Dice_BCE = focal_loss + dice_loss

        return Dice_BCE
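
For context, the loss is called directly on the raw logits and a binary mask, e.g. (random tensors as a stand-in for an actual batch):

criterion = DiceFocalLoss()
logits = torch.randn(1, 1, 640, 640).cuda()                # raw model output, no sigmoid applied
mask = torch.randint(0, 2, (1, 640, 640)).float().cuda()   # binary ground-truth mask
print(criterion(logits, mask))                             # scalar tensor combining the focal and Dice terms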

While debugging the model, I take a single (image, mask) pair from my dataset and train only on it. Surprisingly, when using autocast the loss does not change but remains constant, indicating that the model weights are not being updated. I am able to recreate the issue with a simple 2-layer convolution model as well.
The toy model is defined as -

model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))
model = model.cuda().train()

and my training loop is -

scaler = GradScaler()  # torch.cuda.amp.GradScaler()
optim = torch.optim.SGD(model.parameters(), lr=1e-5)

for i in range(10):
    with autocast():
        out = model(image)
        l = loss(out, mask)
    scaler.scale(l).backward()
    scaler.step(optim)
    scaler.update()
    print(l)

Loss values without autocast -

tensor(2.6734, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.0208, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.3117, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8833, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7073, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6678, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6844, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7187, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7563, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7941, device='cuda:0', grad_fn=<AddBackward0>)

Loss values with autocast -

tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8.9766, device='cuda:0', grad_fn=<AddBackward0>)

As you can see, there is no change at all, which implies the model weights are not getting updated.
Could anyone please point out what is going wrong?
TIA

I’m not sure how the FP32 code snippet is used, but note that you are not zeroing out the gradients. However, even after adding optim.zero_grad(), neither the FP32 nor the AMP model converges nicely:

model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))
model = model.cuda().train()

scaler = torch.cuda.amp.GradScaler()

optim = torch.optim.SGD(model.parameters(), lr=1e-5)
image = torch.randn(1, 3, 24, 24).cuda()
mask = torch.empty(1, 1, 24, 24).uniform_().cuda()
loss = nn.BCEWithLogitsLoss()

use_amp = False

for i in range(1000):
    optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model(image)
        l = loss(out, mask)
    if use_amp:
        scaler.scale(l).backward()
        scaler.step(optim)
        scaler.update()
    else:
        l.backward()
        optim.step()
        
    print(l)

Thanks a lot for replying @ptrblck. Sorry, I did not notice that I missed the zero_grad. I missed it in my original model too and freaked out :sweat_smile:. On my single data point, the model converges decently in both modes.
But even without zero_grad, why did the FP32 loss decrease while the FP16 loss remained constant (even though the model weights change)? The loss should at least vary, right (since the weights are not the same after optim.step())?

Yes, I would assume that both models converge. Could you post an executable code snippet using e.g. random input tensors, which would show the different behaviors?

This is the setup I use -

scaler = GradScaler()
loss = DiceFocalLoss()

image = torch.randn((1, 3, 640, 640)).cuda()
mask = torch.randint(0, 2, (1, 640, 640)).cuda().float()

model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))
model = model.cuda().train()

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
use_amp = False
loss_list = []

for i in range(10):
    #optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model(image)
        l = loss(out, mask)
    if use_amp:
        scaler.scale(l).backward()
        scaler.step(optim)
        scaler.update()
    else:
        l.backward()
        optim.step()

    print(l)

print("=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-==-=-==-==-=-==-==-=-=-==-")
model1  = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))

use_amp = True
model1.cuda().train()
for i in range(10):
    #optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model1(image)
        l = loss(out, mask)
    if use_amp:
        scaler.scale(l).backward()
        scaler.step(optim)
        scaler.update()
    else:
        l.backward()
        optim.step()

    print(l)

The output is -

tensor(2.2960, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2791, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2631, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2475, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2321, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2169, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.2018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.1868, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.1719, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.1571, device='cuda:0', grad_fn=<AddBackward0>)
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-==-=-==-==-=-==-==-=-=-==-
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.8875, device='cuda:0', grad_fn=<AddBackward0>)

Process finished with exit code 0

The top block is FP32 and the bottom one FP16.
There is no change in the FP16 losses. One should expect the same behavior in both cases, right? The loss values might differ, but they should at least decrease. The loss is the Dice focal loss described above.

In your code snippet you are not recreating an optimizer for the AMP run, so the model won’t be updated.

Generally, that’s the assumption. However, from my experience this assumption only holds if the FP32 training is “healthy”, i.e. if the loss/activations are not blowing up (i.e. the model is not diverging), as the AMP training might not be able to recover from such a state.

After adding the missing optimizer, note that both runs indeed “change”, but the overall training might not be well defined.

E.g. in some runs the FP32 loss seems to get negative values and explodes afterwards:

tensor(-30.6151, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-35.7750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-42.5553, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-52.2006, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-67.1675, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-93.4945, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-150.9680, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-356.5783, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2775.4045, device='cuda:0', grad_fn=<AddBackward0>)
tensor(369.3389, device='cuda:0', grad_fn=<AddBackward0>)

in others both losses seem to be reduced:

FP32
tensor(0.6024, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6009, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5995, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5981, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5967, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5953, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5939, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5924, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5910, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5896, device='cuda:0', grad_fn=<AddBackward0>)
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-==-=-==-==-=-==-==-=-=-==-
AMP
tensor(0.5286, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5277, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5268, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5258, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5249, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5239, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5230, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5220, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5211, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5202, device='cuda:0', grad_fn=<AddBackward0>)

while others break with a negative loss again:

FP32
tensor(5.8274, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.6770, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.5385, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.4073, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.2816, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.1607, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5.0440, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4.9311, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4.8220, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4.7163, device='cuda:0', grad_fn=<AddBackward0>)
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-==-==-=-==-==-=-==-==-=-=-==-
AMP
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-9.6253, device='cuda:0', grad_fn=<AddBackward0>)

Is a negative loss expected in your case?

I am building a model for semantic segmentation, and the loss function is a combination of focal and Dice loss. Dice loss cannot be negative and neither can focal loss (as the BCE term would be between 0 and 1), so I don't think it should be negative (right?).

The main issue was that I was not observing consistent behavior between FP32 and FP16, which has now been resolved (it was due to the missing optimizer.zero_grad()). If I still get the same loss, it can be attributed to the network itself (e.g. always giving a zero output due to dying ReLUs).
Thanks a lot :grinning:

Update - There is one thing, though, that I am not able to justify. In the following code I check, at every iteration, whether the weights are still the same even after optim.step().
The code is -

image = torch.randn((1, 3, 640, 640)).cuda()
mask = torch.randint(0, 2, (1, 640, 640)).cuda().float()

model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))
model = model.cuda().train()

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
use_amp = True
loss_list = []

for i in range(10):
    #optim.zero_grad()
    a1 = list(model.parameters())[0].clone()
    a2 = list(model.parameters())[1].clone()
    a3 = list(model.parameters())[2].clone()
    a4 = list(model.parameters())[3].clone()
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model(image)
        l = loss(out, mask)
    if use_amp:
        scaler.scale(l).backward()
        scaler.step(optim)
        scaler.update()
    else:
        l.backward()
        optim.step()
    print(l)
    b1 = list(model.parameters())[0].clone()
    b2 = list(model.parameters())[1].clone()
    b3 = list(model.parameters())[2].clone()
    b4 = list(model.parameters())[3].clone()
    print(torch.equal(a1, b1), torch.equal(a2, b2), torch.equal(a3, b3), torch.equal(a4, b4))
    print("-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=")

And interestingly the output is -

tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.0811, device='cuda:0', grad_fn=<AddBackward0>)
True True True True
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

and the output with use_amp=False is -

tensor(4.9359, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.8451, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.7606, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.6795, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.6010, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.5246, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.4501, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.3773, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.3061, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
tensor(4.2365, device='cuda:0', grad_fn=<AddBackward0>)
False False False False
-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

Process finished with exit code 0

Now, irrespective of what the loss might be (the same, negative, or for that matter even NaN), the weights should not stay the same. I am even using clone() to create copies and compare those. And this behavior is not consistent: if I keep rerunning the code above, sometimes the weights are equal and sometimes they are not (in FP16 mode). This should not be happening, right?

Did you maybe forget to clone the parameters of model1 (used in the AMP run) and might be comparing the parameters of model (used in the FP32 run) during the AMP run?
If I copy-paste your code and adapt the AMP model name, I get valid updates (seen via False...).
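
For illustration, a sketch of the adapted AMP comparison, giving model1 its own optimizer and cloning model1's parameters (otherwise the snippet is unchanged):

model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1),
                       nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0))
model1 = model1.cuda().train()
optim1 = torch.optim.Adam(model1.parameters(), lr=1e-3)  # fresh optimizer over model1's parameters

for i in range(10):
    optim1.zero_grad()
    before = [p.clone() for p in model1.parameters()]    # snapshot model1's parameters, not model's
    with torch.cuda.amp.autocast():
        out = model1(image)
        l = loss(out, mask)
    scaler.scale(l).backward()
    scaler.step(optim1)
    scaler.update()
    after = [p.clone() for p in model1.parameters()]
    print(l)
    print([torch.equal(b, a) for b, a in zip(before, after)])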

I wrote fresh code to check whether the model is being updated or not. I first set use_amp=True and then run the script again and again (I am simply doing python3 main.py from the terminal). I did this 10 times and found that only 3 times did I get (False, False, ...).

When I do not use AMP, I always get valid updates. This is the behavior on my machine (though I don't think it should matter). I even deleted the cached files just in case.

If you keep rerunning that piece of code (not in a loop, a fresh run of the script every time), you might get True True True True at least once as well. Are you seeing such behavior on your machine?

Yes, I also see it sometimes when I’m hitting the invalid loss values.
In that case the gradients also seem to blow up, so the scaler will skip the optimization step and reduce its scaling factor.
You could manually check it via:

s0 = scaler.get_scale()
scaler.update()
s1 = scaler.get_scale()
if s0 != s1:
    print('skipped')
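
For completeness, a sketch of how this check could sit inside the AMP loop above (flagging only a decreased scale, since the scale can also grow again after a long run of successful steps):

for i in range(10):
    optim.zero_grad()
    with torch.cuda.amp.autocast():
        out = model(image)
        l = loss(out, mask)
    scaler.scale(l).backward()
    scaler.step(optim)                 # internally skipped if inf/NaN gradients are found
    scale_before = scaler.get_scale()
    scaler.update()                    # lowers the scale whenever the step was skipped
    if scaler.get_scale() < scale_before:
        print(f'iteration {i}: optimizer step skipped, scale reduced to {scaler.get_scale()}')
    print(l)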

Oh okay, yes that makes sense. Thanks a lot for your time @ptrblck