Encountered strange overfitting during training with `num_workers > 0` when shuffling is off

Hi,

Overview

Recently I encountered some strange overfitting when setting num_workers > 0. It turned out we accidentally had shuffle=False, yet our models were still training fine with num_workers=0. What’s strange is that when we set num_workers > 0, the models started to overfit: the training loss kept going down, but the validation loss started going up after a while. We solved it by setting shuffle=True, but I would still be interested in some insight into why this happened. I find it odd that no shuffling at all trained fine, that the small amount of shuffling or data duplication possibly introduced by the multiple dataloader workers caused overfitting, and that fully shuffling the data made training behave normally again.

Some background

We are training an algorithm called Noise2Void, a self-supervised method for denoising images corrupted by uncorrelated noise. The basic idea is that the network cannot learn the uncorrelated noise, but it can learn the underlying structures present in an image. This means that, trained with an L2 loss, it should learn to predict the expected value of a pixel given the structure in the surrounding pixels.
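As a rough illustration of the objective (a minimal sketch, not the actual CAREamics implementation), the loss is an L2 term evaluated only at pixels that were masked out of the network input:

import torch

def n2v_masked_l2(prediction: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask is 1 at the pixels whose original values were hidden from the network
    # input; evaluating the loss only there forces the network to predict each
    # hidden pixel from its neighbourhood, where the noise is uncorrelated
    return torch.sum(mask * (prediction - target) ** 2) / torch.sum(mask)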

This is implemented as part of our Python library CAREamics, which aims to make this algorithm (N2V) and others more accessible to the scientific community.

Example result using DataLoader parameters num_workers=4 and shuffle=False

[Figure omitted: bad denoising result]

Side note

We are using PyTorch Lightning, and we accidentally didn’t set shuffle=True because of a misunderstanding: we assumed PyTorch Lightning automatically applies shuffling to the train dataloader via a torch.utils.data.DistributedSampler. After some further investigation it turned out that the data is in fact not shuffled during training unless shuffle=True is explicitly set in the train_dataloader.
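Concretely, with a standard LightningDataModule (a hypothetical method sketch, not the CAREamics code), nothing shuffles the training data unless it is requested explicitly:

from torch.utils.data import DataLoader

def train_dataloader(self):
    # shuffling must be requested here; in our setup Lightning did not enable it for us
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,  # this was the missing piece
        num_workers=4,
    )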

I haven’t brought this up in the Lightning forums because I don’t believe it is caused by Lightning, but if it turns out it is, I will raise it there.

More notes

This is also raised as an issue on our GitHub repo, where you can see examples of good training results as well as figures plotting the training and validation losses.

Thanks!

Hi, thanks for the report. The problem looks very interesting! We’ve tried a few sanity checks on the data returned by the dataloader but had no luck. Can you share the full script to reproduce the issue?

We initialized a toy dataloader like this:

import numpy as np
from careamics.lightning import (
    create_careamics_module,
    create_train_datamodule,
)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# training data
rng = np.random.default_rng(42)
train_array = rng.integers(0, 255, (64, 64)).astype(np.float32)
val_array = rng.integers(0, 255, (64, 64)).astype(np.float32)

# create lightning module
model = create_careamics_module(  
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
)

# create data module
data = create_train_datamodule(
    train_data=train_array,
    val_data=val_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
    dataloader_params={"num_workers": 4},
    # transforms=[]
)

# create trainer
trainer = Trainer(  
    max_epochs=1,
    default_root_dir=".",
    callbacks=[
        ModelCheckpoint(  
            dirpath="./checkpoints",
            filename="basic_usage_lightning_api",
        )
    ],
)

data.prepare_data()
data.setup()

# check the dataloader 
for batch in data.train_dataloader():
    print(batch)
    print("===")

Hi,

Of course!

In my Python env I first installed PyTorch and then CAREamics using pip install "careamics[dev,examples,wandb]==0.0.4.2".

The following script downloads an example dataset with artificially added Gaussian noise. I use our CAREamist API for convenience, but I have tried variations where I instantiate the different Lightning modules individually, with the same results.

In another experiment I wrote a Lightning callback to save the first n batches of each epoch, to check whether the data was shuffled; as far as I can tell, it is only shuffled when shuffle is explicitly set to True.
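A sketch of that kind of callback (illustrative names, not the exact code I used) is shown here; the full reproduction script follows after it.

from pathlib import Path

import torch
from pytorch_lightning.callbacks import Callback

class SaveFirstBatches(Callback):
    """Save the first `n` training batches of every epoch to disk for inspection."""

    def __init__(self, n=4, out_dir="saved_batches"):
        self.n = n
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # keep only the first `n` batches of each epoch
        if batch_idx < self.n:
            path = self.out_dir / f"epoch{trainer.current_epoch}_batch{batch_idx}.pt"
            torch.save(batch, path)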

from pathlib import Path
import glob

import numpy as np
import matplotlib.pyplot as plt
import tifffile
import wandb

from careamics_portfolio import PortfolioManager
from careamics import CAREamist
from careamics.config import create_n2v_configuration

# instantiate the data portfolio manager
portfolio = PortfolioManager()

# and download the data
root_path = Path("./data")
files = portfolio.denoising.N2V_BSD68.download(root_path)

# create paths for the data
data_path = Path(root_path / "denoising-N2V_BSD68.unzip/BSD68_reproducibility_data")
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"

# select single file for displaying results
test_files = glob.glob("*.tif*", root_dir=test_path)
test_files.sort()
test_file = test_files[0]
test_image = tifffile.imread(test_path/test_file)

dataloader_params_combinations = [
    {"shuffle": False, "num_workers": 0},
    {"shuffle": False, "num_workers": 4},
    {"shuffle": True, "num_workers": 4},
]
predictions = {}
for dataloader_params in dataloader_params_combinations:
    experiment_name = (
        f"shuffle-{dataloader_params['shuffle']}"
        f"-num_workers-{dataloader_params['num_workers']}"
    )
    config = create_n2v_configuration(
        experiment_name=experiment_name,
        data_type="tiff",
        axes="SYX",
        patch_size=(64, 64),
        batch_size=64,
        num_epochs=10,
        augmentations=[], # No data augmentation
        logger="wandb",
        dataloader_params=dataloader_params
    )

    careamist = CAREamist(source=config, work_dir=Path(__file__).parent)
    careamist.train(train_source=train_path, val_source=val_path)

    experiment_prediction = careamist.predict(
        source=test_path / test_file,
        data_type=config.data_config.data_type,
        axes="YX",
        image_means=config.data_config.image_means,
        image_stds=config.data_config.image_stds,
        tile_size=(128, 128),
        tile_overlap=(48, 48),
        tta_transforms=False
    )
    predictions[experiment_name] = experiment_prediction[0]
    wandb.finish()

fig, axes = plt.subplots(2, 2)
fig.set_size_inches(16, 12)
axes.flatten()[0].imshow(test_image, cmap="gray")
axes.flatten()[0].set_title("Input")
for i, dataloader_params in enumerate(dataloader_params_combinations):
    experiment_name = (
        f"shuffle-{dataloader_params['shuffle']}"
        f"-num_workers-{dataloader_params['num_workers']}"
    )
    axes.flatten()[i+1].imshow(np.squeeze(predictions[experiment_name]), cmap="gray")
    axes.flatten()[i+1].set_title(experiment_name)
fig.suptitle("Predictions")
fig.tight_layout()
fig.savefig(Path(__file__).parent/"prediction_figure.png")
plt.show()

My most recent run of this script has the following training losses, plotted by WandB: [figure omitted]

And the inference results: [figure omitted]