PyTorch Lightning reproducibility

I’m experiencing an issue with reproducibility in PyTorch Lightning.

Despite setting deterministic=True, pl.seed_everything(42) and workers=True, I’m getting different validation losses for the same epoch when changing the number of epochs in the trainer.

For example, with the number of epochs set to 1, I get a validation loss of 5.89, but with it set to 2, I get a validation loss of 5.13 after the first epoch. I’ve verified that the results are reproducible when I run the same number of epochs multiple times (I get the same sequence of validation losses), but changing the number of epochs yields different results for the same epoch.

What is the cause of this behavior and is there a way to ensure reproducible results across different numbers of epochs?

I am using PyTorch Lightning version 2.2.3.

import time

import hydra
import lightning as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from model import EarthQuakeModel
from omegaconf import DictConfig
from torchgeo.datamodules import QuakeSetDataModule


@hydra.main(config_path="configs", config_name="default", version_base=None)
def main(args: DictConfig):
    pl.seed_everything(42, workers=True)
    torch.set_float32_matmul_precision("medium")

    data_module = QuakeSetDataModule(**args.dataset)
    model = EarthQuakeModel(**args.model)

    experiment_id = time.strftime("%Y%m%d-%H%M%S")

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"checkpoints/{experiment_id}",
        filename="earthquake-detection-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    trainer = pl.Trainer(
        **args.trainer,
        # deterministic=True,
        callbacks=[checkpoint_callback, lr_monitor],
        log_every_n_steps=50,
        precision="32-true",
    )

    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    main()

I don’t know where the number of epochs is set, but you could check whether changing it also changes the number of validation runs, etc. If so, you could then check whether running a validation loop calls into the PRNG and thereby changes the random numbers drawn afterwards.
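
To illustrate the mechanism, here is a minimal sketch using Python's standard `random` module rather than Lightning itself. The names `draw_batch` and `run` are hypothetical, and the "validation" step is modeled simply as a few extra PRNG draws. The same principle should apply to PyTorch's global generator, since any call that draws from it advances its state.

```python
import random


def draw_batch(rng, n=3):
    """Stand-in for anything random in a training epoch (shuffling, dropout, ...)."""
    return [rng.random() for _ in range(n)]


def run(extra_draws_between_epochs):
    """Simulate two 'epochs' starting from the same seed.

    extra_draws_between_epochs models a hypothetical validation step that
    consumes random numbers between epoch 1 and epoch 2.
    """
    rng = random.Random(42)  # same seed on every call
    epoch1 = draw_batch(rng)
    for _ in range(extra_draws_between_epochs):
        rng.random()  # extra PRNG consumption shifts the stream
    epoch2 = draw_batch(rng)
    return epoch1, epoch2


baseline = run(0)
shifted = run(2)

# Same seed and same call sequence up to this point: epoch 1 should match.
print("epoch 1 identical:", baseline[0] == shifted[0])
# The extra draws moved the stream, so epoch 2 should no longer match.
print("epoch 2 identical:", baseline[1] == shifted[1])
```

Each configuration is still reproducible on its own (calling `run(2)` twice gives the same values), which matches the symptom of getting stable results only when the number of epochs is unchanged.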

I usually start by looking for simple mistakes too! Here that means checking whether changing the number of epochs changes the number of validation runs, and whether running the validation loop draws from the PRNG. There may also be more to the situation than that.
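
One way to run that check is to snapshot the generator state before and after the suspect step and compare the two. The sketch below does this with the standard library's `random.getstate()`. Note that `consumes_rng` is a hypothetical helper, not part of Lightning. For PyTorch, the analogous snapshot would come from `torch.get_rng_state()` (and `torch.cuda.get_rng_state()` on GPU), assuming the code in question uses the global generators.

```python
import random


def consumes_rng(fn):
    """Return True if calling fn() advances Python's global PRNG state."""
    before = random.getstate()
    fn()
    return random.getstate() != before


def deterministic_step():
    # No random calls, so the generator state should be left untouched.
    return sum(range(10))


def random_step():
    # Draws one number from the global generator.
    return random.random()


random.seed(42)
print("deterministic_step consumes RNG:", consumes_rng(deterministic_step))
print("random_step consumes RNG:", consumes_rng(random_step))
```

Wrapping a validation pass in a check like this would show whether it draws from the generator at all, which is the first thing to rule out.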