My goal is to get the serialized byte object of the trained model (from the rank 0 process) back in the main function, where the model is also created.
- Is there any problem with the code below?
- Is the name of the rank == 0 process 'MainProcess'?
- Do I need a lock to synchronize access to the shared memory, given that only the process with rank == 0 writes to it? (A small standalone demo of the shared-array mechanism follows this list.)
- Are there any optimizations I could make while still achieving the goal?
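For context, here is my understanding of the shared-array mechanism, as a small standalone demo (not part of the training script): slice assignment into an mp.Array('c', n) requires exactly n bytes, which is why the array below is sized from a serialized copy of the model.

import multiprocessing as mp

payload = b"hello"
shared = mp.Array('c', len(payload))  # allocate len(payload) shared c_char slots
shared[:] = payload                   # OK: lengths match exactly
print(shared.raw)                     # b'hello'
# shared[:] = b"hello!"               # would raise ValueError (size mismatch)

The full script: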
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import TensorDataset
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import io
def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend=dist.Backend.NCCL, init_method="env://", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)
def train(rank, world_size, model, result_byte_object):
    # Initialize the distributed training backend
    ddp_setup(rank, world_size)
    model = model.to(rank)
    # Wrap the model with DDP
    ddp_model = DDP(model, device_ids=[rank])
    # Define the loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    # Create some dummy data for training
    inputs = torch.randn(100, 10)
    targets = torch.randn(100, 5)
    # Build the dataset once and share it between the sampler and the DataLoader
    dataset = TensorDataset(inputs, targets)
    sampler = DistributedSampler(dataset)
    # Create a DataLoader with the DistributedSampler (shuffle must stay False when a sampler is given)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=False,
        pin_memory=True,
        sampler=sampler,
    )
    # Training loop
    for epoch in range(10):
        sampler.set_epoch(epoch)  # Re-seed the sampler so shuffling differs across epochs
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(rank)
            batch_targets = batch_targets.to(rank)
            optimizer.zero_grad()  # Zero the gradients per batch, not once per epoch
            outputs = ddp_model(batch_inputs)
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()
        print('Rank', rank, 'Epoch', epoch, 'Loss', loss.item())
    # Serialize the trained model to a byte object on rank 0 only
    if rank == 0:
        buffer = io.BytesIO()
        # Save the unwrapped module on CPU so the serialized size matches the size
        # used to allocate the shared array in main(); saving ddp_model itself (or
        # CUDA tensors) yields a different byte length and the slice assignment
        # below would raise a ValueError
        torch.save(ddp_model.module.cpu(), buffer)
        result_byte_object[:] = buffer.getvalue()
    dist.destroy_process_group()
def main():
    # Number of GPUs to use
    num_gpus = torch.cuda.device_count()
    # Determine the size of the shared byte object by serializing the untrained model
    model = Model()
    buffer = io.BytesIO()
    torch.save(model.cpu(), buffer)
    # Create a shared byte object (one c_char per serialized byte)
    result_byte_object = mp.Array('c', size_or_initializer=len(buffer.getvalue()))
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["NCCL_DEBUG"] = "INFO"  # NCCL accepts VERSION/WARN/INFO/TRACE; DETAIL is not a valid level
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging
    mp.spawn(train, args=(num_gpus, model, result_byte_object), nprocs=num_gpus)
    # mp.spawn joins all workers before returning, so this code always runs in the main process
    model_byte_object = io.BytesIO()
    if mp.current_process().name == 'MainProcess':
        # .raw returns the shared array's contents as bytes; passing the Array
        # object itself to io.BytesIO raises a TypeError
        model_byte_object = io.BytesIO(result_byte_object.raw)
    return model_byte_object.getvalue()

if __name__ == '__main__':
    main()
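For completeness, this is how I intend to consume the bytes returned by main(): a minimal sketch, assuming the Model class is importable wherever the bytes are loaded (torch.load needs it to unpickle a full module).

model_bytes = main()
restored_model = torch.load(io.BytesIO(model_bytes), map_location='cpu')  # add weights_only=False on PyTorch >= 2.6
print(restored_model)  # the trained Model instance, on CPU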