Loss not getting printed when manual optimization is turned on

### Bug description

I ported a simple toy example to PyTorch Lightning and found that when manual optimization is turned on (to implement a custom optimizer), the loss doesn't get printed in the progress bar. I'm not sure if I'm missing some setting.

Any help is appreciated.

### What version are you seeing the problem on?

v2.1

### How to reproduce the bug

Baseline toy example (works fine and prints the loss correctly):

import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[_] for _ in X]
        y = [[_] for _ in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0
        if labels is not None:
            loss = self.criterion(outputs, labels)
        return loss, outputs

    def train_dataloader(self):
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=1000)

    def training_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


if __name__ == '__main__':
    model = MyModel()
    trainer = pl.Trainer(max_epochs=100, gpus=1)
    trainer.fit(model)

    X = torch.Tensor([[1.0], [51.0], [89.0]])
    _, y = model(X)
    print(y)

Ported toy example with manual optimization (runs without any errors, but no loss is shown in the progress bar):

import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from sam import SAM  # third-party SAM optimizer (provides first_step/second_step)
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[_] for _ in X]
        y = [[_] for _ in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()
        self.automatic_optimization = False  # manual optimization: SAM needs two forward/backward passes per step

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0
        if labels is not None:
            loss = self.criterion(outputs, labels)
        return loss, outputs

    def train_dataloader(self):
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=1000)

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        # first forward-backward pass
        loss_1 = self.compute_loss(batch)
        self.manual_backward(loss_1)
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        loss_2 = self.compute_loss(batch)
        self.manual_backward(loss_2)
        optimizer.second_step(zero_grad=True)

        return {"loss": loss_1}

    def compute_loss(self, batch):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return loss

    def configure_optimizers(self):
        base_optimizer = Adam
        optimizer = SAM(self.parameters(), base_optimizer, lr=0.01, rho=0.05)
        return optimizer


if __name__ == '__main__':
    model = MyModel()
    trainer = pl.Trainer(max_epochs=100, gpus=1)
    trainer.fit(model)

    X = torch.Tensor([[1.0], [51.0], [89.0]])
    _, y = model(X)
    print(y)
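
Do I need to log the loss explicitly when manual optimization is on? For reference, here is a minimal sketch of the training_step above with an explicit self.log(..., prog_bar=True) call added (not sure whether that is the intended way to get the loss into the progress bar):

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        # first forward-backward pass
        loss_1 = self.compute_loss(batch)
        self.manual_backward(loss_1)
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        loss_2 = self.compute_loss(batch)
        self.manual_backward(loss_2)
        optimizer.second_step(zero_grad=True)

        # explicitly log the loss so it can appear in the progress bar
        self.log("loss", loss_1, prog_bar=True)

        return {"loss": loss_1}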



### Error messages and logs

Baseline code (the time, it/s, loss, and v_num fields are printed correctly):

Epoch 99: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 203.64it/s, loss=4.16e+07, v_num=32]
Trainer.fit stopped: max_epochs=100 reached.
Epoch 99: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 197.27it/s, loss=4.16e+07, v_num=32]
tensor([[ 1.9360],
        [46.3258],
        [80.0620]], grad_fn=<AddmmBackward0>)


Ported code (the loss is not being printed):

Epoch 99: 100%|████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 180.96it/s, v_num=33]
tensor([[  2.8524],
        [102.8746],
        [178.8915]], grad_fn=<AddmmBackward0>)



### Environment

pytorch-lightning==1.9.5
lightning==2.1.3
python=3.10.13=hd12c33a_1_cpython
torch==2.1.2+cu118
torchaudio==2.1.2+cu118
torchmetrics==1.3.0
torchvision==0.16.2+cu118
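
Note: both pytorch-lightning 1.9.5 and lightning 2.1.3 are installed, and "import pytorch_lightning as pl" resolves to the 1.9.5 package, which still accepts the gpus Trainer argument. If the repro should instead run against lightning 2.1 (the version reported above), where gpus has been removed in favor of accelerator/devices, I assume the import and Trainer construction would change roughly as follows:

import lightning.pytorch as pl

trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1)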

You might want to cross-post the question to the Lightning discussion board, as it seems to be related to its usage.