DDP hangs when a HF module is present

I ran into a strange issue while trying to use DDP + logging, with a single main process handling the logging. I dug a bit deeper and it appears to be related to the presence of HuggingFace modules, but I’m not entirely certain. I’ve distilled it into an MWE (pasted below), followed by some commentary.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group, get_rank
from transformers import AutoModel


bert = AutoModel.from_pretrained("prajjwal1/bert-small")

class SampleModule(nn.Module):
    def __init__(self):
        # nn.Module must be initialized before submodules can be assigned
        super().__init__()
        # random model
        self.small_model = bert
        self.W = nn.Linear(128, 32, bias=False)

    def forward(self, sample, debug=False):
        # Fake loss
        return self.W(sample[0] + sample[1] + sample[2]).mean()

batch_size = 32

def get_loader():
    while True:
        yield (
            torch.randn(batch_size, 128),
            torch.randn(batch_size, 128),
            torch.randn(batch_size, 128),
        )

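# get_rank() needs an initialized process group; torchrun supplies rank and world size via the environment
init_process_group(backend="nccl")
rank = get_rank()
main_process = rank == 0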
device = f"cuda:{rank}"

loader = iter(get_loader())
model = SampleModule().to(device)

for name, param in model.named_parameters():
    if name.startswith("small_model"):
        param.requires_grad = False

model = DDP(model, device_ids=[rank])
optimizer = optim.Adam(model.parameters())

# Block A
# batch = next(loader)
# batch = (batch[0].to(device), batch[1].to(device), batch[2].to(device))
# loss = model(batch)
# loss.backward()

for m, batch in enumerate(loader):
    print(f"Process {rank} on iter {m} is ready.")

    batch = (batch[0].to(device), batch[1].to(device), batch[2].to(device))
    loss = model(batch)

    # somehow torch.no_grad + the HF models cause problems...
    # but not if I "initialize" them by running a backward pass first?
    if main_process:
        print(f"logging on process {rank}")
        # can this cause problems? that batch is already on device but includes...
        with torch.no_grad():
            batch = next(loader)
            batch = (batch[0].to(device), batch[1].to(device), batch[2].to(device))
            test_loss = model(batch)


To run this code, I use a command like CUDA_VISIBLE_DEVICES=6,7 torchrun --nproc-per-node 2 mwe.py. When I do, the program hangs after a couple of iterations. The hang disappears with any of changes 1, 2, and 4 below, and moves to a later iteration with change 3:

  1. Uncomment block A, performing a single backward pass before doing anything else.
  2. Remove small_model from the SampleModule (it stands in for a generic HuggingFace module, not the specific ones I’m using for my task).
  3. Remove torch.no_grad. The program still hangs, but at a later iteration.
  4. Allow both processes to enter the “logging” section (that is, drop the if main_process check).
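For reference, change 4 can be sketched as below: every rank runs the extra no_grad forward pass through the DDP wrapper, and only rank 0 prints the result. This is a minimal sketch, not the original training script. TinyModel and logging_step are hypothetical names, and the single-process gloo group backed by a HashStore is only an assumption to let the snippet run on CPU without torchrun.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class TinyModel(nn.Module):
    """Hypothetical stand-in for SampleModule: one linear layer, scalar output."""

    def __init__(self):
        super().__init__()
        self.W = nn.Linear(128, 32, bias=False)

    def forward(self, x):
        return self.W(x).mean()


def logging_step(model, batch, rank):
    """Run the extra forward pass on every rank; only rank 0 reports it."""
    with torch.no_grad():
        test_loss = model(batch)
    if rank == 0:
        print(f"logging on process {rank}: test_loss={test_loss.item():.4f}")
    return test_loss


if __name__ == "__main__":
    # Assumption: a one-process CPU group, used only to make the sketch runnable.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    model = DDP(TinyModel())
    logging_step(model, torch.randn(32, 128), dist.get_rank())
    dist.destroy_process_group()
```

A related variant, which I have not tested against the MWE above, is to keep the rank-0-only branch but call the unwrapped model.module inside it, so that the DDP wrapper's forward is never invoked on just one rank.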

I know I may have better luck on the HuggingFace forums, but this seems to be a strange interaction between DDP and HF. If anyone has any insight, it would be much appreciated!

UPDATE: I’m not quite sure why, but initializing DDP with find_unused_parameters=True (even without any unused parameters) seems to solve the problem, at some performance cost.
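The workaround from the update amounts to one extra constructor argument. Here is a minimal sketch, assuming a hypothetical FrozenPlusHead module in which a frozen nn.Linear stands in for the HuggingFace model (unlike BERT, it has no buffers, so it only illustrates the call and does not reproduce the hang). The single-process gloo group is again just an assumption to make it runnable on CPU.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class FrozenPlusHead(nn.Module):
    """Hypothetical stand-in: a frozen submodule plus a trainable linear head."""

    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(128, 128)  # plays the role of small_model
        self.W = nn.Linear(128, 32, bias=False)
        for p in self.frozen.parameters():
            p.requires_grad = False

    def forward(self, x):
        # As in the MWE, the frozen submodule is never called in forward.
        return self.W(x).mean()


def wrap_with_unused_param_detection(model):
    """Wrap the model in DDP with find_unused_parameters=True, as in the update."""
    return DDP(model, find_unused_parameters=True)


if __name__ == "__main__":
    # Assumption: a one-process CPU group, used only to make the sketch runnable.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    model = wrap_with_unused_param_detection(FrozenPlusHead())
    loss = model(torch.randn(32, 128))
    loss.backward()
    dist.destroy_process_group()
```

In the real script the only change would be adding find_unused_parameters=True to the existing DDP(model, device_ids=[rank]) call.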