Issue with Gradient Clipping When Using Stochastic Weight Averaging in PyTorch Lightning

Hi everyone,

I’m hitting an error during gradient clipping when using Stochastic Weight Averaging (SWA) in PyTorch Lightning. It occurs consistently in the last epoch of training. Below is my script:

import time
import hydra
import lightning as pl
import torch
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, StochasticWeightAveraging
from model import EarthQuakeModel
from omegaconf import DictConfig
from torchgeo.datamodules import QuakeSetDataModule

@hydra.main(config_path="configs", config_name="mobilenet", version_base=None)
def main(args: DictConfig):
    pl.seed_everything(100195, workers=True)
    torch.set_float32_matmul_precision("medium")

    model = EarthQuakeModel(**args.model)
    data_module = QuakeSetDataModule(**args.dataset)
    data_module.train_aug = None

    run_id = time.strftime('%Y%m%d-%H%M%S')
    wandb_logger = WandbLogger(
        project="smac",
        name=run_id,
        log_model="all"
    )
    wandb_logger.watch(model, log="gradients")

    checkpoint_callback = ModelCheckpoint(
        monitor="val_mae",
        dirpath=f"checkpoints/{run_id}",
        filename="earthquake-detection-{epoch:02d}-{val_mae:.4f}",
        save_top_k=3,
        mode="min",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")

    trainer = pl.Trainer(
        **args.trainer,
        deterministic=True,
        callbacks=[
            StochasticWeightAveraging(swa_lrs=1e-2),
            checkpoint_callback,
            lr_monitor
        ],
        logger=wandb_logger,
        log_every_n_steps=50,
        precision="32-true",
        gradient_clip_val=0.5,
        gradient_clip_algorithm="value"
    )

    trainer.fit(model, datamodule=data_module)

if __name__ == "__main__":
    main()

Error Message:

Error executing job with overrides: ['trainer.max_epochs=20']

RuntimeError: Expected !nested_tensorlist[0].empty() to be true, but got false.

I am using PyTorch Lightning 2.2.3 and PyTorch 2.3.0.

I have tried different swa_epoch_start values and different precision settings, but the issue persists. Gradient clipping and SWA each work fine on their own; the error only appears when both are enabled.
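One thing I can check is whether clipping is being called at a point where no parameter has a gradient. This is only my guess at the cause, not something I have confirmed. The sketch below uses plain PyTorch; the helper name clip_grad_value_if_any is made up for illustration and is not part of PyTorch or Lightning. It skips the clipping call when every .grad is None and clips by value otherwise:

```python
import torch
from torch import nn


def clip_grad_value_if_any(parameters, clip_value):
    """Clip gradients by value, but skip the call if no parameter has a gradient.

    Hypothetical helper for debugging. Returns True if clipping ran,
    False if it was skipped because every .grad was None.
    """
    params_with_grad = [p for p in parameters if p.grad is not None]
    if not params_with_grad:
        # Nothing to clip: avoid handing an empty tensor list to the clipping utility.
        return False
    torch.nn.utils.clip_grad_value_(params_with_grad, clip_value)
    return True


# Tiny stand-in model; the real model would be EarthQuakeModel.
model = nn.Linear(4, 1)

# Before any backward pass, every .grad is None, so clipping is skipped.
ran_before_backward = clip_grad_value_if_any(model.parameters(), 0.5)

# After a backward pass with a deliberately large loss, clipping runs.
loss = (model(torch.ones(2, 4)) * 100).sum()
loss.backward()
ran_after_backward = clip_grad_value_if_any(model.parameters(), 0.5)

# Largest absolute gradient entry after clipping.
max_abs_grad = max(p.grad.abs().max().item() for p in model.parameters())
```

If the guess is right, the same guard could presumably go into an override of configure_gradient_clipping on the LightningModule, calling self.clip_gradients(...) only when some parameter has a gradient. I have not verified that this avoids the error.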

Has anyone encountered a similar issue, or does anyone have insight into why this might be happening? Any suggestions on how to debug or fix this would be greatly appreciated!

Thanks in advance!