Unexplained behaviour: gradient accumulation vs. DDP — why are the gradients different?

Can someone explain the discrepancy I’m seeing with this snippet on a 4-GPU machine? (I can reproduce it on a 2-GPU machine as well.)

import pytorch_lightning as ptl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import time

# Configs
# Parallelism: the dataset size is tied to the device count.
NUM_DEVICES = 4
NUM_SAMPLES = NUM_DEVICES

# Data / model dimensions.
BATCH_SIZE = 1
INPUT_DIM = 4
HIDDEN_DIM = 4
NUM_LABELS = 3
SEQ_LEN = 4
VOCAB_SIZE = 6

# Reproducibility.
SEED = 42


# reshape functions
def reshape_with_expand(
    last_hidden_state: torch.Tensor, num_labels: int
) -> torch.Tensor:
    """Repeat each row of a 2-D tensor ``num_labels`` times, consecutively.

    Input row ``i`` ends up at output rows ``i * num_labels`` through
    ``i * num_labels + num_labels - 1``; the result has shape
    ``(rows * num_labels, cols)``.
    """
    # (B, H) -> (B, 1, H): a singleton axis to broadcast across labels.
    with_label_axis = last_hidden_state.unsqueeze(1)
    # (B, 1, H) -> (B, num_labels, H): expand is a view, nothing copied yet.
    broadcast = with_label_axis.expand(-1, num_labels, -1)
    # Merge the batch and label axes -> (B * num_labels, H).
    return broadcast.flatten(0, 1)


# dummy dataset
class DummyDataset(Dataset):
    """Fixed random (features, labels) pairs for the toy model.

    Both tensors are drawn from torch's global RNG when the instance is
    created, so their contents depend on the RNG state at construction time.
    """

    def __init__(self, num_samples=NUM_SAMPLES):
        self.num_samples = num_samples
        # Features are drawn before labels; swapping the order would change
        # the generated data for a given seed.
        feature_shape = (num_samples, INPUT_DIM)
        label_shape = (num_samples, NUM_LABELS, SEQ_LEN)
        self.pixel_data = torch.randn(*feature_shape)
        self.label_data = torch.randint(0, VOCAB_SIZE, label_shape)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        features = self.pixel_data[idx]
        targets = self.label_data[idx]
        return features, targets


class PlToyModel(ptl.LightningModule):
    """Tiny encoder/decoder LightningModule used to compare gradient snapshots.

    ``final_gradients`` maps parameter name -> a clone of ``param.grad``
    taken in ``on_after_backward`` during the last batch of the epoch.
    """

    def __init__(self):
        super().__init__()
        self.final_gradients = {}

        # Construction order fixes which RNG draws initialise which layer.
        self.encoder = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.decoder = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def on_after_backward(self):
        # Only the final batch of the epoch is snapshotted.
        if not self.trainer.is_last_batch:
            return
        print(f"[Rank {self.global_rank}] Capturing final gradients.", flush=True)
        self.final_gradients.update(
            {
                param_name: weight.grad.clone()
                for param_name, weight in self.named_parameters()
                if weight.grad is not None
            }
        )

    def training_step(self, batch, batch_idx):
        print(f"GPU {self.global_rank}: batch_idx: {batch_idx}")
        pixel_values, labels = batch
        hidden = self.encoder(pixel_values)
        # (B, HIDDEN_DIM) -> (B * NUM_LABELS, HIDDEN_DIM)
        #                 -> (B * NUM_LABELS, SEQ_LEN, HIDDEN_DIM)
        per_label = reshape_with_expand(hidden, NUM_LABELS)
        decoder_input = per_label.unsqueeze(1).repeat(1, SEQ_LEN, 1)
        logits = self.decoder(decoder_input)
        flat_logits = logits.view(-1, VOCAB_SIZE)
        flat_labels = labels.view(-1)
        loss = F.cross_entropy(flat_logits, flat_labels)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


if __name__ == "__main__":
    # NOTE(review): the dataset draws from torch's RNG (torch.randn /
    # torch.randint in DummyDataset) *before* seed_everything is called, so
    # SEED does not control its contents. Under strategy="ddp" with several
    # devices, Lightning re-executes this script once per rank, so every rank
    # rebuilds this dataset independently. Seeding first would make the data
    # explicitly reproducible across ranks.
    shared_dataset = DummyDataset()

    # Seed right before building the model so its initial weights are
    # reproducible (the DDP model below is seeded the same way).
    ptl.seed_everything(SEED, workers=True)

    model_accumulate = PlToyModel()
    loader_accumulate = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )

    # Reference run: a single device accumulating gradients over NUM_SAMPLES
    # batches before each optimizer step.
    # NOTE(review): this uses strategy="ddp" with devices=1 in the same
    # script that later launches a multi-device DDP run. Because the script
    # is re-executed on every rank, this trainer presumably also runs inside
    # the spawned rank processes. Consider dropping strategy="ddp" here --
    # TODO confirm.
    trainer_accumulate = ptl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=1,
        accumulate_grad_batches=NUM_SAMPLES,
    )
    print("---Starting 'accumulate'---")
    trainer_accumulate.fit(model_accumulate, loader_accumulate)

    # Gradients captured by PlToyModel.on_after_backward on the last batch.
    accumulated_grads = model_accumulate.final_gradients

    # second run: same seed, so the model starts from the same weights.
    ptl.seed_everything(SEED, workers=True)

    model_ddp = PlToyModel()

    loader_ddp = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )

    # DDP run across NUM_DEVICES processes; presumably each rank receives its
    # own share of shared_dataset (one sample per rank here) -- verify.
    trainer_ddp = ptl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=NUM_DEVICES,
    )
    print("---Starting 'ddp'---")
    trainer_ddp.fit(model_ddp, loader_ddp)
    ddp_grads = model_ddp.final_gradients

    # Crude pause before printing; this is not a synchronization barrier.
    time.sleep(1)

    # Only rank 0 prints the comparison, using its own local gradient copies.
    if trainer_ddp.is_global_zero:
        print("\n" + "=" * 50)
        print("---Final gradient norm comparison ---")
        print("=" * 50)
        print(f"{'Parameter':<20} | {'Accumulated Norm':<15} | {'DDP norms':<15}")
        print("-" * 50)
        for name in accumulated_grads:
            accumulated_norm = accumulated_grads[name].norm().item()
            # Raises KeyError if the DDP run captured no gradient for `name`.
            ddp_norm = ddp_grads[name].norm().item()
            print(f"{name:<20} | {accumulated_norm:<15.4f} | {ddp_norm:<15.4f}")
        print("=" * 50)

Why aren’t the gradients the same? In both cases, shouldn’t the result be (g1 + g2 + g3 + g4) / 4, where g_i is the gradient computed from the i-th sample?

Also, why don’t I see something like `GPU 1: batch_idx: 1` printed in the DDP case?

Any help would be appreciated!

When you say:

Why aren’t the gradients the same?

What do you actually mean?

why don’t I see something like GPU 1: batch_idx: 1 being printed in the ddp case?

Sometimes a trainer doesn’t show stdout output from ranks other than rank 0; I’m not sure whether that is the case here.

Apologies — I think we can close this issue. It was a misunderstanding on my part of how PyTorch Lightning’s Trainer runs its own DDP strategy.

The gradients in the accumulate case and in the ddp case are the same. This can be verified with the following snippet:

import sys
import time

import pytorch_lightning as ptl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
from torch.utils.data import DataLoader, Dataset
import os

# Configs
SEED = 42  # passed to ptl.seed_everything

# Device count; the dataset holds exactly one sample per device.
NUM_DEVICES = 4
NUM_SAMPLES = NUM_DEVICES

# Shapes of the toy data and model.
BATCH_SIZE = 1
INPUT_DIM = 4
HIDDEN_DIM = 4
NUM_LABELS = 3
SEQ_LEN = 4
VOCAB_SIZE = 6


# reshape functions
def reshape_with_expand(
    last_hidden_state: torch.Tensor, num_labels: int
) -> torch.Tensor:
    """Tile every row of a 2-D tensor ``num_labels`` times in place order.

    An input of shape ``(B, H)`` yields ``(B * num_labels, H)``, where the
    copies of input row ``i`` are adjacent.
    """
    # Add a label axis, broadcast it to num_labels (a view), then fold the
    # batch and label axes together.
    tiled = last_hidden_state.unsqueeze(1).expand(-1, num_labels, -1)
    return tiled.flatten(0, 1)


# dummy dataset
class DummyDataset(Dataset):
    """Random (features, labels) pairs generated once, at construction.

    Contents come from torch's global RNG, so seed it before creating the
    dataset if the data must be reproducible.
    """

    def __init__(self, num_samples=NUM_SAMPLES):
        self.num_samples = num_samples
        # Draw order (features, then labels) determines the data for a seed.
        self.pixel_data = torch.randn((num_samples, INPUT_DIM))
        self.label_data = torch.randint(
            low=0, high=VOCAB_SIZE, size=(num_samples, NUM_LABELS, SEQ_LEN)
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return (self.pixel_data[idx], self.label_data[idx])


class PlToyModel(ptl.LightningModule):
    """Tiny encoder/decoder LightningModule used to compare gradient snapshots.

    ``exp_name`` tags the debug prints so runs can be told apart.
    ``final_gradients`` maps parameter name -> a clone of ``param.grad``
    captured on the last batch of the epoch: first in
    ``on_before_optimizer_step``, then overwritten in ``on_train_batch_end``
    if gradients are still present at that point.
    """

    def __init__(self, exp_name: str):
        super().__init__()
        self.exp_name = exp_name
        self.final_gradients = {}

        # Construction order fixes which RNG draws initialise which layer.
        self.encoder = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.decoder = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def _snapshot_grads(self):
        """Store a clone of every non-None parameter gradient in final_gradients."""
        for param_name, weight in self.named_parameters():
            if weight.grad is None:
                continue
            self.final_gradients[param_name] = weight.grad.clone()

    def on_before_optimizer_step(self, optimizer):
        if not self.trainer.is_last_batch:
            return
        print(f"[Rank {self.global_rank}] Capturing final gradients.")
        self._snapshot_grads()

    def training_step(self, batch, batch_idx):
        pixel_values, labels = batch
        # Debug output: whether LOCAL_RANK is set in this process's environment.
        print(f"LOCAL_RANK in os.environ: {'LOCAL_RANK' in os.environ}, name: {self.exp_name}")
        print(f"Getting LOCAL_RANK: {os.getenv('LOCAL_RANK')}")
        # Sent to stderr and flushed immediately rather than left in a buffer.
        print(
            f"GPU {self.global_rank}: batch_idx: {batch_idx}, pixel_values: {pixel_values}, name: {self.exp_name}",
            file=sys.stderr,
            flush=True,
        )

        # (B, HIDDEN_DIM) -> (B * NUM_LABELS, SEQ_LEN, HIDDEN_DIM)
        decoder_input = (
            reshape_with_expand(self.encoder(pixel_values), NUM_LABELS)
            .unsqueeze(1)
            .repeat(1, SEQ_LEN, 1)
        )
        logits = self.decoder(decoder_input)
        loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), labels.view(-1))
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if not self.trainer.is_last_batch:
            return
        print(f"[Rank {self.global_rank}] on_train_batch_end - Capturing gradients")
        # Gradients may already have been cleared here; report either way.
        if self.encoder.weight.grad is None:
            print(f"[Rank {self.global_rank}] Gradients already cleared!")
            return
        print(
            f"[Rank {self.global_rank}] encoder.weight.grad[0,0] = {self.encoder.weight.grad[0,0].item():.6f}"
        )
        self._snapshot_grads()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


def check_state_dict_equality(
    state1, state2, name1="state1", name2="state2", *, atol=1e-7, rtol=1e-5
):
    """Check that two state dicts hold the same keys and numerically equal tensors.

    Args:
        state1: First mapping of parameter/buffer name -> tensor.
        state2: Second mapping of parameter/buffer name -> tensor.
        name1: Label for ``state1`` used in messages.
        name2: Label for ``state2`` used in messages.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose`` (the default
            is torch's own default, so behaviour matches the plain call).

    Raises:
        AssertionError: if the key sets differ, a tensor's shape or dtype
            differs, or values differ beyond tolerance. Raised explicitly
            rather than via ``assert`` so the check still runs under
            ``python -O``.
    """
    keys1, keys2 = set(state1.keys()), set(state2.keys())
    if keys1 != keys2:
        raise AssertionError(
            f"Keys mismatch between {name1} and {name2}: "
            f"only in {name1}: {sorted(keys1 - keys2, key=str)}, "
            f"only in {name2}: {sorted(keys2 - keys1, key=str)}"
        )

    for key in state1.keys():
        tensor1, tensor2 = state1[key], state2[key]
        # torch.allclose may broadcast mismatched shapes (or raise a
        # RuntimeError on shape/dtype mismatch), so compare them explicitly.
        if tensor1.shape != tensor2.shape or tensor1.dtype != tensor2.dtype:
            raise AssertionError(
                f"Parameter {key} shape/dtype mismatch between {name1} and {name2}: "
                f"{tuple(tensor1.shape)}/{tensor1.dtype} vs "
                f"{tuple(tensor2.shape)}/{tensor2.dtype}"
            )
        if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol):
            raise AssertionError(
                f"Parameter {key} mismatch between {name1} and {name2}"
            )

    print(f"✓ {name1} and {name2} are identical")


if __name__ == "__main__":
    # Add this to ensure deterministic behavior
    # (request deterministic cuDNN kernels and disable cuDNN autotuning,
    # which can otherwise select different algorithms between runs).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Seed BEFORE building the dataset so its random contents are fixed by
    # SEED. Lightning re-executes this script on every rank under DDP, so
    # each rank then constructs the same data.
    ptl.seed_everything(SEED, workers=True)
    shared_dataset = DummyDataset()

    # Re-seed so the model's initial weights do not depend on how much RNG
    # the dataset consumed; the DDP model below is seeded identically.
    ptl.seed_everything(SEED, workers=True)
    model_accumulate = PlToyModel(exp_name="accumulate")
    accumulate_initial_state_dict = deepcopy(model_accumulate.state_dict())
    loader_accumulate = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=0,  # load in the main process; no worker processes
    )

    # Reference run: one device (no explicit strategy), accumulating
    # gradients over NUM_SAMPLES batches before each optimizer step.
    trainer_accumulate = ptl.Trainer(
        accelerator="gpu",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=1,
        accumulate_grad_batches=NUM_SAMPLES,
        # enable_progress_bar=False,
        log_every_n_steps=1,
    )
    print("---Starting 'accumulate'---")
    trainer_accumulate.fit(model_accumulate, loader_accumulate)

    # Snapshot captured by PlToyModel's hooks on the last batch.
    accumulated_grads = model_accumulate.final_gradients

    # second run
    ptl.seed_everything(SEED, workers=True)

    model_ddp = PlToyModel(exp_name="ddp")
    ddp_initial_state_dict = deepcopy(model_ddp.state_dict())

    # Sanity check: both runs must start from the same weights.
    check_state_dict_equality(
        accumulate_initial_state_dict,
        ddp_initial_state_dict,
        "accumulate_initial_state_dict",
        "ddp_initial_state_dict",
    )

    loader_ddp = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=0,
    )

    # DDP run across NUM_DEVICES processes; presumably each rank receives
    # one of the NUM_SAMPLES samples -- verify via the per-rank prints.
    trainer_ddp = ptl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=NUM_DEVICES,
        log_every_n_steps=1,
    )
    print("---Starting 'ddp'---")
    trainer_ddp.fit(model_ddp, loader_ddp)
    ddp_grads = model_ddp.final_gradients


    # Every rank runs this loop concurrently; rank i prints after roughly
    # (i + 1) seconds, so the tables usually appear in rank order. This is
    # sleep-based, not a barrier, so output can still interleave.
    # NOTE(review): the header says "local_rank" but the check uses
    # global_rank; these coincide only on a single node.
    for i in range(NUM_DEVICES):
        time.sleep(1)
        if trainer_ddp.global_rank == i:
            print("\n" + "=" * 50)
            print(f"---Final gradient norm comparison, local_rank {i} ---")
            print("=" * 50)
            print(f"{'Parameter':<20} | {'Accumulated Norm':<15} | {'DDP norms':<15}")
            print("-" * 50)
            for name in accumulated_grads:
                accumulated_norm = accumulated_grads[name].norm().item()
                # Raises KeyError if this rank captured no gradient for `name`.
                ddp_norm = ddp_grads[name].norm().item()
                print(f"{name:<20} | {accumulated_norm:<15.4f} | {ddp_norm:<15.4f}")
            print("=" * 50)

The output of this snippet looks something like this:

==================================================
---Final gradient norm comparison, local_rank 0 ---
==================================================
Parameter            | Accumulated Norm | DDP norms      
--------------------------------------------------
encoder.weight       | 0.1327          | 0.1327         
encoder.bias         | 0.0649          | 0.0649         
decoder.weight       | 0.2206          | 0.2206         
decoder.bias         | 0.1619          | 0.1619         
==================================================

==================================================
---Final gradient norm comparison, local_rank 1 ---
==================================================
Parameter            | Accumulated Norm | DDP norms      
--------------------------------------------------
encoder.weight       | 0.1327          | 0.1327         
encoder.bias         | 0.0649          | 0.0649         
decoder.weight       | 0.2206          | 0.2206         
decoder.bias         | 0.1619          | 0.1619         
==================================================

==================================================
---Final gradient norm comparison, local_rank 2 ---
==================================================
Parameter            | Accumulated Norm | DDP norms      
--------------------------------------------------
encoder.weight       | 0.1327          | 0.1327         
encoder.bias         | 0.0649          | 0.0649         
decoder.weight       | 0.2206          | 0.2206         
decoder.bias         | 0.1619          | 0.1619         
==================================================

==================================================
---Final gradient norm comparison, local_rank 3 ---
==================================================
Parameter            | Accumulated Norm | DDP norms      
--------------------------------------------------
encoder.weight       | 0.1327          | 0.1327         
encoder.bias         | 0.0649          | 0.0649         
decoder.weight       | 0.2206          | 0.2206         
decoder.bias         | 0.1619          | 0.1619         
==================================================

What PyTorch Lightning does under the hood is described here: GPU training (Intermediate) — PyTorch Lightning 2.5.1.post0 documentation. It launches the script multiple times, each time with different environment variables (as in the example below), so we have to be careful about any randomization in the script.

# example for 3 GPUs DDP

MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --accelerator 'gpu' --devices 3 --etc

MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=1 python my_file.py --accelerator 'gpu' --devices 3 --etc

MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=2 python my_file.py --accelerator 'gpu' --devices 3 --etc