Resuming DDP training produces different results from training from scratch

I’m learning DDP and want to implement resuming training from the last snapshot so that it produces exactly the same results as a model trained from scratch. I have read the Reproducibility notes and applied as many of the settings as I could to guarantee deterministic behavior.

My environment is:

Ubuntu 20.04
Python 3.10
PyTorch 1.12
CUDA Version 11.6

Here is my code named mnist_demo_multi_resume.py:

import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import os
import random
import numpy as np
from natsort import natsorted


def ddp_setup():
    init_process_group(backend='nccl')


def init_seeds(seed):
    # refer to https://pytorch.org/docs/stable/notes/randomness.html
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ["CUBLAS_WORKSPACE_CONFIG"] =":16:8"
    torch.use_deterministic_algorithms(mode=True, warn_only=True)
    

class Trainer:
    def __init__(self, model, loader, optimizer, save_every, snapshots_dir="./snapshots"):
        self.loader = loader
        self.optimizer = optimizer
        self.save_every = save_every
        
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.epoch = 0
        self.model = model.to(self.local_rank)

        if not os.path.exists(snapshots_dir):
            os.mkdir(snapshots_dir)
        # Pick the most recent snapshot (natural sort orders snapshot_2.pt before snapshot_10.pt)
        last_spst = natsorted(os.listdir(snapshots_dir))[-1] if os.listdir(snapshots_dir) != [] else ' '
        snapshot_path = os.path.join(snapshots_dir, last_spst)
        if os.path.exists(snapshot_path):
            self._load_snapshot(snapshot_path)
        self.model = DDP(self.model, device_ids=[self.local_rank])
    
    def _load_snapshot(self, snapshot_path):
        snapshot = torch.load(snapshot_path)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
        self.epoch = snapshot["EPOCH"]
        print(f"Resuming training from snapshot {snapshot_path} at Epoch {self.epoch}")
    
    def _save_snapshot(self):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "OPTIMIZER_STATE": self.optimizer.state_dict(),
            "EPOCH": self.epoch,
        }
        PATH = "snapshots/snapshot_{}.pt".format(self.epoch)
        torch.save(snapshot, PATH)
        print(f"Epoch {self.epoch} | Training snapshot saved at {PATH}")
    
    def _run_batch(self, batch_idx, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()
        self.loss = loss.item()
    
    def _run_epoch(self):
        for batch_idx, data in enumerate(self.loader, 1):
            source, targets = data
            source, targets = source.cuda(self.local_rank), targets.cuda(self.local_rank)
            self._run_batch(batch_idx, source, targets)
            if batch_idx == len(self.loader):
                print(f"[GPU{self.local_rank}] "
                    f"[Epoch {self.epoch}, {batch_idx}/{len(self.loader)} "
                    f"Batchsize: {self.loader.batch_size}] "
                    f"Loss: {self.loss:.4f} "
                    f"TgExample: {data[1][:20]}")
    
    def train(self, max_epochs):
        for epoch in range(self.epoch+1, max_epochs+1):
            self.epoch = epoch
            self.loader.sampler.set_epoch(epoch)
            self._run_epoch()
            if self.local_rank == 0 and self.epoch % self.save_every == 0:
                self._save_snapshot()

def prepare_dataloader(data_size=-1, batch_size=256):
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = torchvision.datasets.MNIST("./Datasets", 
                                         train=True, 
                                         transform=trans, 
                                         target_transform=None, 
                                         download=True)
    if data_size != -1:
        dataset = Subset(dataset, range(data_size))  # first-n data
    loader = DataLoader(dataset=dataset, 
                        batch_size=batch_size,
                        sampler=DistributedSampler(dataset))
    return loader


def load_train_objs():
    model = torchvision.models.resnet50(num_classes=10)
    model.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, optimizer


def main(max_epochs, save_every):
    init_seeds(233)
    ddp_setup()
    model, optimizer = load_train_objs()
    loader = prepare_dataloader(data_size=-1, batch_size=256)
    trainer = Trainer(model, loader, optimizer, save_every)
    trainer.train(max_epochs)
    
    destroy_process_group()
    
    
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('max_epochs', type=int)
    parser.add_argument('save_every', type=int)
    args = parser.parse_args()
    
    main(args.max_epochs, args.save_every)

Experiment

First, I repeatedly train the model from scratch in multi-GPU DDP mode and get the same training loss at every epoch, which indicates that my determinism settings do work.

I run torchrun --standalone --nproc_per_node=4 mnist_demo_multi_resume.py 4 1 and get (Fig. 1):

Then I delete snapshot_4.pt and run torchrun --standalone --nproc_per_node=4 mnist_demo_multi_resume.py 4 1 again to resume training (Fig. 2):

We can see that on GPU0 the training data labels used (TgExample) are the same, but the loss when training from scratch is 0.0233 while the loss when resuming is 0.0295. The losses on the other GPUs also differ.

But if I repeat that second step, i.e. delete snapshot_4.pt and resume training again, I get the same results as shown in Fig. 2.

I am very confused about this situation and wonder what could be the reason for the above difference.

At first I suspected that the random number generator in the DataLoader was causing trouble behind the scenes, but the following test confused me further.

I reduce the dataset size using torch.utils.data.Subset and use only the first 1000 samples for training, i.e. loader = prepare_dataloader(data_size=1000, batch_size=256) in my code.

Here are the results of training from scratch (Fig. 3):

and here are the results of resuming training (Fig. 4):

As shown in Fig. 3 and Fig. 4, they produce the same results as expected.

In summary, I’m not sure what important detail I’m missing that leads to the above behavior. I hope someone can help solve this problem.

Thanks for providing a way to reproduce! This makes our lives much easier.

I have spent two hours digging into this and have not figured it out, so I am sharing some findings to possibly expedite things in case someone else picks this up:

  • Batch norm seems to be checkpointed correctly (e.g. running_mean and running_var are correct and passing track_running_stats=False to the BatchNorm2d instances does not fix the issue).
  • Calling torch.manual_seed(0) immediately before creating the data loader iterator (i.e. before for batch_idx, data in enumerate(self.loader, 1):) does not fix the issue.
  • Saving the RNG state via torch.get_rng_state() and loading it via torch.set_rng_state() does not fix the issue (see the sketch after this list for what I mean).
  • Passing shuffle=False for DistributedSampler does not fix the issue (which was promising because shuffle=True by default for DistributedSampler). Passing shuffle=False for DataLoader should not be needed since it defaults to False.
  • Running dataset = Subset(dataset, range(len(dataset))) does not fix the issue (which was promising in case Subset was enforcing determinism when without it was not).
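
For reference, what I mean by saving and restoring the RNG state is roughly the following sketch (extended here to also cover the CUDA, NumPy, and Python generators, which goes beyond what I actually tested), e.g. stored as an extra key in the snapshot dict:

import random
import numpy as np
import torch

def get_full_rng_state():
    """Collect the state of every generator that could affect training."""
    return {
        "torch_cpu": torch.get_rng_state(),
        "torch_cuda": torch.cuda.get_rng_state_all(),  # one state per visible GPU
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }

def set_full_rng_state(state):
    """Restore all generators from a previously saved state dict."""
    torch.set_rng_state(state["torch_cpu"])
    torch.cuda.set_rng_state_all(state["torch_cuda"])
    np.random.set_state(state["numpy"])
    random.setstate(state["python"])

# e.g. snapshot["RNG_STATE"] = get_full_rng_state() in _save_snapshot,
#      set_full_rng_state(snapshot["RNG_STATE"]) in _load_snapshot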

Something else to try may be to see if this issue replicates when not using DDP at all.

Thanks for your time. I had already tried the methods you mention, and got the same results as you.

Another thing: if I run the script with only one process (a single GPU), the issue doesn’t happen. The command is:

rm snapshots/*.pt; 
torchrun --standalone --nproc_per_node=1 mnist_demo_multi_resume.py 4 1; 
rm snapshots/snapshot_4.pt; 
torchrun --standalone --nproc_per_node=1 mnist_demo_multi_resume.py 4 1

and I can get:

So I think there may be something underlying multi-GPU mode that is not deterministic, and I would like to know exactly what causes the issue.

Sorry, I do not have the time at the moment to try this myself, but I had one idea: Each rank should have different batch norm statistics since those depend on the data the rank has processed. However, you only save the state dict from rank 0, and each rank loads that rank-0 state dict, imposing rank 0’s batch norm statistics onto the other ranks.

There are two things you may try:

  1. Try using convert_sync_batchnorm() to convert the BatchNorm2d instances in ResNet50 into SyncBatchNorm instances, which will synchronize statistics across ranks.
  2. Save per-rank state dicts from the DDP instance itself (i.e. self.model and not self.model.module) and load them into the new DDP instance. This is because upon creating a new DDP instance, it will broadcast the parameters and buffers from rank 0 to all ranks, which would defeat the purpose of saving per-rank state dicts.
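
A rough sketch of both suggestions, adapted to the Trainer in this thread (paths and dict keys are illustrative only, and I have not verified that this resolves the issue):

import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_and_maybe_resume(model, snapshot_path=None):
    """Convert BN layers to SyncBatchNorm, wrap in DDP, then load a per-rank snapshot.

    Loading happens *after* DDP construction, because DDP broadcasts rank 0's
    parameters and buffers to all ranks when the wrapper is created.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    if snapshot_path is not None and os.path.exists(snapshot_path):
        snapshot = torch.load(snapshot_path, map_location=f"cuda:{local_rank}")
        model.load_state_dict(snapshot["MODEL_STATE"])  # keys carry the "module." prefix
    return model

def save_per_rank_snapshot(model, optimizer, epoch):
    """Save the DDP-wrapped state dict separately on every rank."""
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.save({"MODEL_STATE": model.state_dict(),       # state dict of the DDP wrapper
                "OPTIMIZER_STATE": optimizer.state_dict(),
                "EPOCH": epoch},
               f"snapshots/snapshot_r{local_rank}_{epoch}.pt")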

Thank you for your quick reply!

I’ve tried the approaches you suggested, adding convert_sync_batchnorm() and saving per-rank state_dicts from the DDP instance itself, but it doesn’t fix the issue. Here is the modified code:

import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import os
import random
import numpy as np


def ddp_setup():
    init_process_group(backend='nccl')


def init_seeds(seed):
    # refer to https://pytorch.org/docs/stable/notes/randomness.html
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ["CUBLAS_WORKSPACE_CONFIG"] =":16:8"
    torch.use_deterministic_algorithms(mode=True, warn_only=True)
    

class Trainer:
    def __init__(self, model, loader, optimizer, save_every, sync_bn, snapshots_dir="./snapshots"):
        self.loader = loader
        self.optimizer = optimizer
        self.save_every = save_every
        
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.epoch = 0
        self.model = model.to(self.local_rank)
        if sync_bn:
            print(f"[GPU{self.local_rank}] Enable SyncBatchNorm")
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        else:
            print(f"[GPU{self.local_rank}] Disable SyncBatchNorm")
        self.model = DDP(self.model, device_ids=[self.local_rank])
        
        # Find the latest snapshot
        if not os.path.exists(snapshots_dir):
            os.mkdir(snapshots_dir)
        # Recover the epoch number from filenames of the form snapshot_r{rank}_{epoch}.pt
        last_epoch = int(sorted([s.strip('snapshot_r').rstrip('.pt').lstrip('0123456789').lstrip('_') 
                                 for s in os.listdir(snapshots_dir)])[-1]) \
                                     if os.listdir(snapshots_dir) != [] else ' '
        snapshot_path = os.path.join(snapshots_dir, 'snapshot_r{}_{}.pt'.format(self.local_rank, last_epoch))
        if os.path.exists(snapshot_path):
            self._load_snapshot(snapshot_path)
    
    def _load_snapshot(self, snapshot_path):
        snapshot = torch.load(snapshot_path, map_location=torch.device(self.local_rank))
        # self.model.module.load_state_dict(snapshot["MODEL_STATE"])
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
        self.epoch = snapshot["EPOCH"]
        print(f"[GPU{self.local_rank}] Resuming training at Epoch {self.epoch} from {snapshot_path} ")
    
    def _save_snapshot(self):
        snapshot = {
            # "MODEL_STATE": self.model.module.state_dict(),
            "MODEL_STATE": self.model.state_dict(),
            "OPTIMIZER_STATE": self.optimizer.state_dict(),
            "EPOCH": self.epoch,
        }
        PATH = "snapshots/snapshot_r{}_{}.pt".format(self.local_rank, self.epoch)
        torch.save(snapshot, PATH)
        print(f"Epoch {self.epoch} | Training snapshot saved at {PATH}")
    
    def _run_batch(self, batch_idx, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()
        self.loss = loss.item()
    
    def _run_epoch(self):
        for batch_idx, data in enumerate(self.loader, 1):
            source, targets = data
            source, targets = source.cuda(self.local_rank), targets.cuda(self.local_rank)
            self._run_batch(batch_idx, source, targets)
            if batch_idx == len(self.loader):
                print(f"[GPU{self.local_rank}] "
                    f"[Epoch {self.epoch}, {batch_idx}/{len(self.loader)} "
                    f"Batchsize: {self.loader.batch_size}] "
                    f"Loss: {self.loss:.4f}, "
                    f"BN1-Wt: {[round(w, 4) for w in self.model.module.bn1.weight[:5].tolist()]}, "
                    # f"BN1-Bias: {[round(w, 4) for w in self.model.module.bn1.bias[:5].tolist()]} "
                    # f"TgExample: {data[1][:20]}"
                    )
    
    def train(self, max_epochs):
        for epoch in range(self.epoch+1, max_epochs+1):
            self.epoch = epoch
            # torch.manual_seed(233 + epoch)
            self.loader.sampler.set_epoch(epoch)
            self._run_epoch()
            if self.epoch % self.save_every == 0:  # self.local_rank == 0 and 
                self._save_snapshot()

def prepare_dataloader(data_size=-1, batch_size=256):
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = torchvision.datasets.MNIST("./Datasets", 
                                         train=True, 
                                         transform=trans, 
                                         target_transform=None, 
                                         download=True)
    if data_size != -1:
        dataset = Subset(dataset, range(data_size))  # Fetching top data_size samples as subset
    loader = DataLoader(dataset=dataset, 
                        batch_size=batch_size,
                        sampler=DistributedSampler(dataset))
    return loader


def load_train_objs():
    model = torchvision.models.resnet50(num_classes=10)
    model.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, optimizer


def main(max_epochs, save_every, sync_bn):
    init_seeds(233)
    ddp_setup()
    model, optimizer = load_train_objs()
    loader = prepare_dataloader(data_size=-1, batch_size=256)
    trainer = Trainer(model, loader, optimizer, save_every, sync_bn)
    trainer.train(max_epochs)
    
    destroy_process_group()
    
    
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('max_epochs', type=int)
    parser.add_argument('save_every', type=int)
    parser.add_argument('--sync_bn', action='store_true')
    args = parser.parse_args()
    
    main(args.max_epochs, args.save_every, args.sync_bn)

Run the command:

rm snapshots/*.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1 --sync_bn; 
rm snapshots/*_4.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1 --sync_bn

and get the following results (Fig. 1):

We can see that the differences still exist.

If I drop --sync_bn, i.e.:

rm snapshots/*.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1; 
rm snapshots/*_4.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1

and get (Fig. 2):

The results in Fig. 2 are different from those in Fig. 1, which suggests that convert_sync_batchnorm does take effect in Fig. 1.

Maybe there is something wrong in my code? It doesn’t make sense…

Actually, all the learnable parameters (e.g. BN’s weight and bias) are synchronized through DDP. Although running_mean and running_var are maintained individually on each GPU, they don’t affect the training process (they are only used in eval mode).
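
One way to check this directly (a sketch of my own, assuming the process group is already initialized): all-gather a flattened copy of a parameter or buffer and look at the spread across ranks.

import torch
import torch.distributed as dist

def max_spread_across_ranks(tensor, name):
    """All-gather a flattened copy of `tensor` from every rank and print, on rank 0,
    the largest elementwise difference between any two ranks."""
    local = tensor.detach().flatten().clone()
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    if dist.get_rank() == 0:
        stacked = torch.stack(gathered)
        spread = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
        print(f"{name}: max spread across ranks = {spread:.3e}")

# e.g. after an epoch, with the DDP-wrapped ResNet50 from this thread:
# max_spread_across_ranks(self.model.module.bn1.weight, "bn1.weight")
# max_spread_across_ranks(self.model.module.bn1.running_mean, "bn1.running_mean")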

Another idea: You could consider changing the model from ResNet50 to a single weight matrix (i.e. nn.Linear(in_dim, out_dim, bias=False)). If that still shows the issue, then we know that it is not related to batch norm, and we can iterate faster.
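
Something like the following sketch would do; the input/output dimensions are my assumption for 28x28 MNIST images and 10 classes:

import torch

class LinearModel(torch.nn.Module):
    """A single weight matrix with no bias and no batch norm."""
    def __init__(self, in_dim=28 * 28, out_dim=10):
        super().__init__()
        self.fc = torch.nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        # Flatten the image into a vector before the linear layer.
        return self.fc(torch.flatten(x, start_dim=1))

# Drop-in replacement in load_train_objs():
# model = LinearModel()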

I implement a simple model without BN like this:

class MyModel(torch.nn.Module):
    def __init__(self, num_classes=1000):
        super(MyModel, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1)),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(64 * 7 * 7, 2048),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

and replace the model in my code as follows:

from model import MyModel

def load_train_objs():
    # model = torchvision.models.resnet50(num_classes=10)
    # model.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    model = MyModel(num_classes=10)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, optimizer

Then I run the commands:

rm snapshots/*.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1; 
rm snapshots/*_4.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1

and get:

...
[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.13176,
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.12624, 
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.16328, 
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.16279, 
...
[GPU1] Resuming training at Epoch 3 from ./snapshots/snapshot_r1_3.pt
[GPU0] Resuming training at Epoch 3 from ./snapshots/snapshot_r0_3.pt
[GPU3] Resuming training at Epoch 3 from ./snapshots/snapshot_r3_3.pt
[GPU2] Resuming training at Epoch 3 from ./snapshots/snapshot_r2_3.pt
[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.13172,
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.12625,  
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.16326, 
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.16279,
...

There are still some differences in the results, such as 0.13176 vs. 0.13172.

When I add a BN layer to MyModel:

...
self.features = torch.nn.Sequential(
    torch.nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1)),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),

    torch.nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
    torch.nn.BatchNorm2d(num_features=64),  # <- Append a BN layer here
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
)
...

and run the same cmd with --sync_bn:

rm snapshots/*.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1 --sync_bn; 
rm snapshots/*_4.pt; 
torchrun --standalone --nproc_per_node=gpu mnist_demo_multi_resume.py 4 1 --sync_bn

I get the following results:

...
[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.04246, 
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.02873,
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.04641, 
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.05780, 
...
[GPU2] Resuming training at Epoch 3 from ./snapshots/snapshot_r2_3.pt
[GPU0] Resuming training at Epoch 3 from ./snapshots/snapshot_r0_3.pt
[GPU3] Resuming training at Epoch 3 from ./snapshots/snapshot_r3_3.pt
[GPU1] Resuming training at Epoch 3 from ./snapshots/snapshot_r1_3.pt
[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.04237, 
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.02873, 
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.04649,
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.05767, 
...

The difference becomes much larger.

I think BN isn’t the decisive factor in this issue; the difference seems to come from the loss computation.

Perhaps I’m missing something important about the loss function in DDP mode? If I use just 1 GPU for the experiments above, it always produces exactly the same results.

When I use a simpler loss function with ResNet50 (without SyncBN), everything works as expected, for example:

def _run_batch(self, batch_idx, source, targets):
    ....
    # loss = F.cross_entropy(output, targets)
    loss = (sum(output.argmax(1) != targets) / len(targets)).requires_grad_(True)  # <- some loss_fn
    ...

The results are:

...
[GPU0] [Epoch 1, 59/59 Batchsize: 256] Loss: 0.90789,
[GPU1] [Epoch 1, 59/59 Batchsize: 256] Loss: 0.92763,
[GPU2] [Epoch 1, 59/59 Batchsize: 256] Loss: 0.85526,
[GPU3] [Epoch 1, 59/59 Batchsize: 256] Loss: 0.90789,

[GPU0] [Epoch 2, 59/59 Batchsize: 256] Loss: 0.88158,
[GPU1] [Epoch 2, 59/59 Batchsize: 256] Loss: 0.90132,
[GPU2] [Epoch 2, 59/59 Batchsize: 256] Loss: 0.90789,
[GPU3] [Epoch 2, 59/59 Batchsize: 256] Loss: 0.91447,

[GPU0] [Epoch 3, 59/59 Batchsize: 256] Loss: 0.91447,
[GPU1] [Epoch 3, 59/59 Batchsize: 256] Loss: 0.95395,
[GPU2] [Epoch 3, 59/59 Batchsize: 256] Loss: 0.88816,
[GPU3] [Epoch 3, 59/59 Batchsize: 256] Loss: 0.92763,

[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.94737,
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.92105,
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.90789,
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.90132,

...
[GPU0] Resuming training at Epoch 3 from ./snapshots/snapshot_r0_3.pt
[GPU3] Resuming training at Epoch 3 from ./snapshots/snapshot_r3_3.pt
[GPU1] Resuming training at Epoch 3 from ./snapshots/snapshot_r1_3.pt
[GPU2] Resuming training at Epoch 3 from ./snapshots/snapshot_r2_3.pt
[GPU0] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.94737,
[GPU1] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.92105,
[GPU2] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.90789,
[GPU3] [Epoch 4, 59/59 Batchsize: 256] Loss: 0.90132,
...

We can see that the loss at [Epoch 4] when resuming training is the same as when training from scratch.


I did some other experiments and found that the loss function does cause the loss fluctuations. I replaced F.cross_entropy() with the following, which is mathematically equivalent:

# loss = F.cross_entropy(output, targets)
predict = torch.log(torch.nn.functional.softmax(output, dim=1))
loss = torch.nn.functional.nll_loss(predict, targets)

To improve numerical stability, I added an epsilon value inside the log, i.e. torch.log(... + eps). Running the above script, the losses obtained with eps=1e-6, eps=1e-8, and eps=1e-10 all differ from each other.

This issue has still not been resolved… I don’t understand why everything works in SINGLE-GPU mode, while the losses differ in MULTI-GPU mode (loss from resumed training vs. loss from training from scratch). In other words, does the use of DDP cause the numerical instability?

Are you seeing any issues when the numerically stable F.cross_entropy method is used instead of your custom approach? Adding a small eps value won’t stabilize your approach; you might want to apply the LogSumExp trick or just use the built-in methods.
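
For illustration, here is a toy sketch (made-up logits, not from your runs) of why log(softmax(x)) can blow up while the LogSumExp-based log_softmax / cross_entropy stay finite:

import torch
import torch.nn.functional as F

logits = torch.tensor([[100.0, 0.0, -100.0]])
targets = torch.tensor([2])  # pick the class with a vanishingly small probability

# Unstable: softmax underflows to 0 for this class, so log(0) = -inf and the loss is inf.
unstable = F.nll_loss(torch.log(F.softmax(logits, dim=1)), targets)

# Stable: log_softmax applies the LogSumExp trick internally,
#   log_softmax(x) = x - (max(x) + log(sum(exp(x - max(x)))))
stable = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# F.cross_entropy fuses log_softmax and nll_loss and is equally stable.
builtin = F.cross_entropy(logits, targets)
print(unstable.item(), stable.item(), builtin.item())  # inf, 200.0, 200.0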