How can I get the trained DDP model in another function?

My goal is to obtain, in the main function, the serialized bytes (a byte object) of the model trained on rank 0. The model is also created in the main function.

  1. Is there any problem in the code below?
  2. The process with rank == 0 is named MainProcess, right?
  3. I don’t need a lock to synchronize access to the shared memory, since only the rank 0 process updates the value, right?
  4. Are there any optimizations I can make while still achieving this goal?
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import TensorDataset
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import io


def ddp_setup(rank, world_size):
    """Join the NCCL process group for this worker and bind it to its GPU.

    Args:
        rank: This process's rank; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Rendezvous address read by the "env://" init method below.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"})
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)


class Model(nn.Module):
    """Minimal network made of a single ``nn.Linear(10, 5)`` layer."""

    def __init__(self):
        super().__init__()
        # The only layer: maps 10 input features to 5 outputs.
        self.fc = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        """Apply the linear layer to ``x`` (its last dimension must be 10)."""
        projected = self.fc(x)
        return projected


def train(rank, world_size, model, result_byte_object):
    """Train ``model`` with DDP on GPU ``rank`` and publish rank 0's trained model.

    Intended as the ``mp.spawn`` target: ``rank`` is supplied by spawn and the
    remaining arguments come from spawn's ``args``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of participating processes.
        model: The model to train (arrives via spawn's ``args``).
        result_byte_object: Shared fixed-size ``mp.Array('c', n)``. Rank 0
            overwrites it with the ``torch.save`` bytes of the trained,
            unwrapped model moved to CPU.

    Raises:
        ValueError: On rank 0, if the serialized trained model is not exactly
            ``len(result_byte_object)`` bytes long. A shared ctypes array cannot
            be resized, so the caller must size it to match.
    """
    ddp_setup(rank, world_size)
    try:
        model = model.to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        criterion = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

        # DistributedSampler only partitions *indices*. It assumes every rank
        # holds the same dataset, so generate the dummy data from a fixed seed
        # instead of letting each process draw different random tensors.
        generator = torch.Generator().manual_seed(0)
        inputs = torch.randn(100, 10, generator=generator)
        targets = torch.randn(100, 5, generator=generator)
        dataset = TensorDataset(inputs, targets)

        sampler = DistributedSampler(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=10,
            shuffle=False,  # shuffling is delegated to the sampler
            pin_memory=True,
            sampler=sampler,
        )

        for epoch in range(10):
            sampler.set_epoch(epoch)  # reseed the sampler's shuffle each epoch

            for batch_inputs, batch_targets in dataloader:
                batch_inputs = batch_inputs.to(rank)
                batch_targets = batch_targets.to(rank)

                # Reset gradients every step; otherwise they accumulate
                # across all batches of the epoch.
                optimizer.zero_grad()
                outputs = ddp_model(batch_inputs)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

            print('Rank', rank, 'Epoch', epoch, 'Loss', loss.item())

        if rank == 0:
            # Serialize the unwrapped module on CPU. It can then be loaded
            # without DDP or a GPU, and its pickle matches the CPU
            # serialization of the bare Model that main() uses to size the
            # shared buffer. Saving the DDP wrapper with CUDA tensors would
            # produce a different length.
            trained = ddp_model.module.cpu()
            buffer = io.BytesIO()
            torch.save(trained, buffer)
            payload = buffer.getvalue()
            if len(payload) != len(result_byte_object):
                raise ValueError(
                    f"Serialized model is {len(payload)} bytes but the shared "
                    f"buffer holds {len(result_byte_object)} bytes; the shared "
                    f"array must be sized to the serialized model."
                )
            # Only rank 0 writes and the parent reads after all workers exit,
            # so no extra locking is needed beyond the Array's own lock.
            result_byte_object[:] = payload
    finally:
        # Tear down the process group even if training or saving fails.
        dist.destroy_process_group()


def main():
    """Train ``Model`` with one DDP process per GPU and return rank 0's result.

    Returns:
        bytes: The ``torch.save`` output that rank 0 wrote into the shared
        buffer.

    Raises:
        RuntimeError: If no CUDA device is available. NCCL-based DDP needs at
            least one GPU, and spawning zero workers would otherwise silently
            return an all-zero buffer.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError(
            "No CUDA devices found; NCCL-based DDP training needs at least one GPU."
        )

    # A shared ctypes array has a fixed size, so size it from a CPU
    # serialization of the (untrained) model.
    model = Model()
    buffer = io.BytesIO()
    torch.save(model.cpu(), buffer)
    result_byte_object = mp.Array('c', size_or_initializer=len(buffer.getvalue()))

    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["NCCL_DEBUG"] = "DETAIL"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.

    # mp.spawn joins by default: it returns only after every worker has
    # exited, and re-raises a worker's exception in this process.
    mp.spawn(train, args=(num_gpus, model, result_byte_object), nprocs=num_gpus)

    # Only the parent process executes this line (the workers run `train`),
    # so no current_process() check is needed. A SynchronizedString is not a
    # bytes-like object; `.raw` returns its full contents as bytes (unlike
    # `.value`, which stops at the first NUL byte).
    return result_byte_object.raw


if __name__ == '__main__':
    # Script entry point. main() returns the trained model's serialized bytes;
    # they are discarded here, so persist them (e.g. write to a file) if needed.
    main()

We can close this post!

I was able to solve my problem.

For the record, here are the changes/optimizations I made:

  1. In the main function I used a managed list (`mp.Manager().list()`), since its size can change dynamically. Then, at the end of main(), I create the `mp.Array`, passing this list as its initializer.
  2. Instead of `mp.spawn`, I am using `mp.Process`, which gives me more control over the processes and lets me wait for all of them to finish.
  3. There is no need to call `current_process()`: all the worker processes have finished at this point, and I only read the model bytes produced on GPU 0 (rank 0).
  4. When saving the model, I save `ddp_model.module` (the unwrapped model) instead of `ddp_model`, so I don’t need to wrap it with DDP again when deserializing it later.