Training a PyTorch Lightning model but the loss doesn't improve (trade-off between batch_size & num_workers?)

I created a model using the PyTorch Lightning Module, and I have a machine with 8 CPUs and a GPU. I chose batch_size = 8 and num_workers = 8. The loss is a dice loss between masks and predictions (2D MRI slices with masks, 2 classes), but it did not improve at all (it stayed at 1). Then, when I changed num_workers to 2, the loss started to decrease.
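For context, this is roughly how I set up the loaders and the training run (the dataset and model names below are just placeholders, my actual Dataset loads the MRI slices and masks):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

# my real Dataset returns (slice, mask) pairs; train_dataset / val_dataset are placeholders
train_loader = DataLoader(train_dataset, batch_size=8, num_workers=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, num_workers=8, shuffle=False)

trainer = pl.Trainer(max_epochs=75, gpus=1)
trainer.fit(model, train_loader, val_loader)  # model is my LightningModule with the dice loss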
Please, what is the best way to explain that? Knowing that the maximum number of epochs is 75, I finally settled on batch_size = 16 and num_workers = 4 (which works, and even better than batch_size = 8 and num_workers = 2). In your view, which combination of batch_size and num_workers is the best trade-off? Thank you.

This is unexpected, as the number of workers (assuming you are referring to num_workers in torch.utils.data.DataLoader) should not change any training properties at all; it is only used to preload the next data batches. The shuffling, processing, etc. of the data also stays the same.
Could you rerun a quick test (not to full convergence) with different seeds and check whether your previous setting was just "unlucky"?
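Something along these lines would be enough (a rough sketch; MyModel stands in for your LightningModule, and seed_everything(..., workers=True) also seeds the DataLoader workers):

import pytorch_lightning as pl

for seed in [0, 1, 42]:
    # seeds Python's random, numpy, and torch; workers=True also seeds DataLoader workers
    pl.seed_everything(seed, workers=True)
    model = MyModel()  # re-create the model so the weight init also uses the new seed
    trainer = pl.Trainer(max_epochs=3, gpus=1)  # short run, not to full convergence
    trainer.fit(model, train_loader, val_loader)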

Yes, it is about num_workers in torch.utils.data.DataLoader.
In fact, I ran some further tests and discovered that no matter what batch size I use (8, 16, …), the loss remains constant at 1 whenever I set num_workers = 8 (the maximum number of CPUs I have). For reference, my machine is a virtual machine with 30 GiB of RAM, 8 CPUs, and a Quadro M4000 GPU.
Here is the result (and there is no improvement in any subsequent epoch):

Epoch 0, global step 121: ‘Val Dice’ reached 1.00000 (best 1.00000), saving model to ‘./logs/lightning_logs/version_1/checkpoints/epoch=0-step=121.ckpt’ as top 1

Epoch 1, global step 242: ‘Val Dice’ was not in top 1

Epoch 2, global step 363: ‘Val Dice’ was not in top 1

And here is the (successful) output when I use num_workers = 4, the primary difference being that I did not use the maximum number of CPUs (which is 8). Here I used batch_size = 16; as I said above, I believe the batch size has nothing to do with this.
The loss started to decrease from the first epoch:

Epoch 0, global step 121: ‘Val Dice’ reached 0.45927 (best 0.45927), saving model to ‘./logs/lightning_logs/version_0/checkpoints/epoch=0-step=121.ckpt’ as top 1

Epoch 1, global step 242: ‘Val Dice’ reached 0.36336 (best 0.36336), saving model to ‘./logs/lightning_logs/version_0/checkpoints/epoch=1-step=242.ckpt’ as top 1

Epoch 2, global step 363: ‘Val Dice’ reached 0.36279 (best 0.36279), saving model to ‘./logs/lightning_logs/version_0/checkpoints/epoch=2-step=363.ckpt’ as top 1

Follow-up (please read my previous post before this one):
Then, using a counterexample, I realized that it may not be so much about using the maximum number of workers as about the relationship between num_workers and batch_size. Allow me to explain.

I ran a test with batch_size = 8 and num_workers = 4 (note that I did not use all of the CPUs). Here too, the loss remains constant at 1 and does not improve, so the problem does not come from using the maximum number of CPUs.

As a result, it naively looks as if the internal structure of the DataLoader only works when batch_size = 4 × num_workers.
Following this naive observation, I tried batch_size = 8 with num_workers = 2, which regrettably did not work either, so I abandoned that idea.

It is this behavior, which I cannot explain or understand, that concerns me. How can I be confident in what I am training? I could conclude that my model is not working when the problem is not caused by the model at all. And should I just randomly change batch_size and num_workers each time until it works?

=============

To be clear, I am including a warning that appears every time I run
trainer.fit(model, train_loader, val_loader)
Regardless of whether the test was successful (the loss started to decrease as expected) or failed (the loss stayed at 1), this warning shows up (I had overlooked it):

/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called full_state_update that has not been set for this class (_ResultMetric). The property determines if update by default needs access to the full metric state. If this is not the case, significant speedups can be achieved and we recommend setting this to False. We provide an checking function from torchmetrics.utilities import check_forward_no_full_state that can be used to check if the full_state_update=True (old and potential slower behaviour, default for now) or if full_state_update=False can be used safely.
  warnings.warn(*args, **kwargs)

Thanks for the detailed update!
I'm not familiar enough with torchmetrics to know whether the warning is related to the issue (I would guess it's unrelated).
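As far as I understand it, the full_state_update mentioned in the warning is just a class attribute you can set on your own torchmetrics Metric subclasses (rough sketch below, assuming torchmetrics >= 0.9; DiceScore is only an illustration, not your actual metric). Since your warning comes from Lightning's internal _ResultMetric, I don't think you need to change anything on your side.

import torch
from torchmetrics import Metric

class DiceScore(Metric):
    # setting this explicitly silences the warning; update() below does not
    # need access to the full metric state, so False is safe here
    full_state_update = False

    def __init__(self):
        super().__init__()
        self.add_state("intersection", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        preds = (preds > 0.5).float()
        self.intersection += (preds * target).sum()
        self.total += preds.sum() + target.sum()

    def compute(self) -> torch.Tensor:
        # dice = 2 * |A ∩ B| / (|A| + |B|)
        return 2 * self.intersection / self.total.clamp(min=1e-8)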

In any case, your use case sounds as if the seeding inside each worker might not work properly, so you might be applying the same "random" transformations to create each batch.
Take a look at this section:

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in FAQ.).

and this:

You are likely using other libraries to generate random numbers in the dataset and worker subprocesses are started via fork. See torch.utils.data.DataLoader’s documentation for how to properly set up random seeds in workers with its worker_init_fn option.

which points to worker_init_fn for adding proper seeding, e.g. via:

import random

import numpy
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # derive a per-worker seed from the torch seed and pass it on to the other RNGs
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

DataLoader(
    train_dataset,  # your existing Dataset
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

Could you check if any other library is used (e.g. numpy) and add the worker_init_fn to the DataLoader?
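One quick way to confirm the duplicated-randomness issue is to read a random number from inside each worker and compare the values across worker ids, e.g. (a small diagnostic sketch, independent of your actual Dataset; on Linux the workers are forked, so without the worker_init_fn above the first numpy value will typically be identical in every worker):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class RandomCheckDataset(Dataset):
    # tiny dummy dataset, only used to inspect the per-worker RNG state
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info is not None else -1
        # without proper seeding, forked workers inherit the same numpy state
        return worker_id, float(np.random.rand())

loader = DataLoader(RandomCheckDataset(), batch_size=1, num_workers=4)
for worker_id, value in loader:
    print(worker_id.item(), value.item())

If you then pass worker_init_fn=seed_worker (and the generator) as above, the printed values should differ between workers.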


Thanks a lot! Yes, with your solution it now works perfectly with any combination of num_workers and batch_size.
A very big thank you.