Bug description
I ported an simple toy example in pytorch lightning and found that when manual optimization is turned on (to implement a custom optimizer), the loss doesn’t get printed. Not sure if I am missing to set some something.
Any help is appreciated.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Baseline toy example (works fine and prints loss correctly):
import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from ptadamw import AdamW
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
class SimpleDataset(Dataset):
def __init__(self):
X = np.arange(10000)
y = X * 2
X = [[_] for _ in X]
y = [[_] for _ in y]
self.X = torch.Tensor(X)
self.y = torch.Tensor(y)
def __len__(self):
return len(self.y)
def __getitem__(self, idx):
return {"X": self.X[idx], "y": self.y[idx]}
class MyModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
self.criterion = MSELoss()
def forward(self, inputs_id, labels=None):
outputs = self.fc(inputs_id)
loss = 0
if labels is not None:
loss = self.criterion(outputs, labels)
return loss, outputs
def train_dataloader(self):
dataset = SimpleDataset()
return DataLoader(dataset, batch_size=1000)
def training_step(self, batch, batch_idx):
input_ids = batch["X"]
labels = batch["y"]
loss, outputs = self(input_ids, labels)
return {"loss": loss}
def configure_optimizers(self):
optimizer = Adam(self.parameters())
return optimizer
if __name__ == '__main__':
model = MyModel()
trainer = pl.Trainer(max_epochs=100, gpus=1)
trainer.fit(model)
X = torch.Tensor([[1.0], [51.0], [89.0]])
_, y = model(X)
print(y)
Ported toy example with manual optimization (runs without any errors but no loss being printed in the progress bar):
import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from sam import SAM
from ptadamw import AdamW
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
class SimpleDataset(Dataset):
def __init__(self):
X = np.arange(10000)
y = X * 2
X = [[_] for _ in X]
y = [[_] for _ in y]
self.X = torch.Tensor(X)
self.y = torch.Tensor(y)
def __len__(self):
return len(self.y)
def __getitem__(self, idx):
return {"X": self.X[idx], "y": self.y[idx]}
class MyModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
self.criterion = MSELoss()
self.automatic_optimization = False
def forward(self, inputs_id, labels=None):
outputs = self.fc(inputs_id)
loss = 0
if labels is not None:
loss = self.criterion(outputs, labels)
return loss, outputs
def train_dataloader(self):
dataset = SimpleDataset()
return DataLoader(dataset, batch_size=1000)
def training_step(self, batch, batch_idx):
optimizer = self.optimizers()
# first forward-backward pass
loss_1 = self.compute_loss(batch)
self.manual_backward(loss_1)
optimizer.first_step(zero_grad=True)
# second forward-backward pass
loss_2 = self.compute_loss(batch)
self.manual_backward(loss_2)
optimizer.second_step(zero_grad=True)
return {"loss": loss_1}
def compute_loss(self, batch):
input_ids = batch["X"]
labels = batch["y"]
loss, outputs = self(input_ids, labels)
return loss
def configure_optimizers(self):
base_optimizer = Adam
optimizer = SAM(model.parameters(), base_optimizer, lr=0.01, rho=0.05)
return optimizer
if __name__ == '__main__':
model = MyModel()
trainer = pl.Trainer(max_epochs=100, gpus=1)
trainer.fit(model)
X = torch.Tensor([[1.0], [51.0], [89.0]])
_, y = model(X)
print(y)
### Error messages and logs
Baseline code: time, it/s, loss and v_num fields are printed correctly
Epoch 99: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 203.64it/s, loss=4.16e+07, v_num=32]
Trainer.fit
stopped: max_epochs=100
reached.
Epoch 99: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 197.27it/s, loss=4.16e+07, v_num=32]
tensor([[ 1.9360],
[46.3258],
[80.0620]], grad_fn=)
Ported code: (loss is not being printed)
Epoch 99: 100%|████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 180.96it/s, v_num=33]
tensor([[ 2.8524],
[102.8746],
[178.8915]], grad_fn=)
### Environment
pytorch-lightning==1.9.5
lightning==2.1.3
python=3.10.13=hd12c33a_1_cpython
torch==2.1.2+cu118
torchaudio==2.1.2+cu118
torchmetrics==1.3.0
torchvision==0.16.2+cu118