UNet requires automatic_optimization=False, but then validation_step is not invoked (Lightning)

I am using PyTorch Lightning to train a UNet model for a binary segmentation task.

Model

from typing import Any  # needed for the return annotation on configure_optimizers

import lightning.pytorch as pl
import segmentation_models_pytorch as smp
import torch.nn as nn

class UNet(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()

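        # No final activation: BCEWithLogitsLoss applies the sigmoid itself,
        # so the model should return raw logits.
        self.model = smp.Unet(encoder_name='efficientnet-b2',
                              encoder_weights='imagenet',
                              classes=num_classes,
                              activation=None)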
    
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.automatic_optimization = True
        self.save_hyperparameters()

    def configure_optimizers(self) -> Any:
        return super().configure_optimizers()

    def training_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('train_loss', loss, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('val_loss', loss, on_epoch=True, on_step=True)

When I have automatic_optimization set to True I get this error:

File "/Users/me/.pyenv/versions/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
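
For context, this message is raised by plain PyTorch whenever backward() is called on a tensor that is not connected to anything that requires gradients. The snippet below is only a minimal sketch reproducing the message outside Lightning; `try_backward` is a hypothetical helper, and it does not claim to identify why the loss in the model above ends up without a grad_fn.

```python
import torch


def try_backward(loss):
    """Call backward() and return the error message, or None on success."""
    try:
        loss.backward()
        return None
    except RuntimeError as err:
        return str(err)


# A tensor created without requires_grad has no grad_fn, so backward() fails.
detached_loss = torch.tensor(1.0)
error_message = try_backward(detached_loss)

# A tensor computed from a parameter that requires grad works as expected.
weight = torch.tensor(2.0, requires_grad=True)
ok_result = try_backward(weight * 3.0)
```

Printing `loss.requires_grad` and `loss.grad_fn` inside training_step is a quick way to check whether the loss is still attached to the model parameters.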

When I have automatic_optimization set to False, Lightning invokes the training_step successfully, but the validation_step never gets called.

What’s the suggested approach when using Lightning? I’ve been reading the Manual Optimization section, but I’m not sure whether it’s right to define the validation step inside training_step.

I found the issue. Setting automatic_optimization to False is the way to go (correct me if I am wrong).

My validation_step was not being called because my validation dataset was returning 0 items :man_facepalming:
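
A cheap guard against this is to check the dataset length before handing it to the Trainer. The helper below is just a sketch; `require_non_empty` is a hypothetical name, and it assumes the dataset implements `__len__` (map-style datasets do, iterable-style ones may not).

```python
def require_non_empty(dataset, name="dataset"):
    """Return len(dataset), raising early if the dataset has no items."""
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError(
            f"{name} is empty, so the corresponding loop would have nothing to run"
        )
    return n_items


# Usage with a plain list standing in for a real Dataset object.
n_val = require_non_empty([("img0", "mask0"), ("img1", "mask1")], name="val_dataset")
```

Calling it on both the training and validation datasets right after constructing them turns a silently skipped validation loop into an immediate, readable error.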