The only option seems to be to log from a single process (rank 0). This code runs fine:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributions import Laplace
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dens1 = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = self.dens1(x)
        # create the scale tensor on the same device as x so the model also runs on GPU
        x = Laplace(x, torch.tensor([1.0], device=x.device))
        return x
class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        sample = {'mod1': torch.rand(1, 16).float(),
                  'mod2': torch.rand(1, 16).float(),
                  'mod3': torch.rand(1, 16).float()}
        label = torch.randint(0, 1, (3,)).float()
        return sample, label

    def __len__(self):
        return 20
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
class Experiment:
    def __init__(self, distributed: bool, dir_logs: str):
        self.logger = None
        self.dir_logs = dir_logs
        self.model = ToyNet()
        self.rank = None
        self.distributed = distributed
        if distributed:
            self.world_size = torch.cuda.device_count()
            assert self.world_size > 1, 'More than 1 GPU needs to be accessible to use distributed training'
        else:
            self.world_size = 1

    def setup_logger(self):
        # initialize the summary writer
        if not os.path.exists(self.dir_logs):
            os.makedirs(self.dir_logs)
        self.logger = SummaryWriter(self.dir_logs)
def train(exp: Experiment, rand_loader: DataLoader):
    rank = exp.rank
    model = exp.model.to(rank)
    if exp.distributed:
        model = DDP(exp.model, device_ids=[rank])
    mse_loss = nn.MSELoss()
    for step, (batch, label) in enumerate(rand_loader):
        label = label.to(rank)
        batch = {k: v.to(rank) for k, v in batch.items()}
        for modality in batch.keys():
            output = model(batch[modality]).mean
            loss = mse_loss(output, label)
            # exp.logger is None on all ranks except the logging one
            if exp.logger:
                exp.logger.add_scalars('train/loss',
                                       {'train_loss': loss.item()},
                                       step)
            loss.backward()
def validate(exp, rand_loader: DataLoader):
    rank = exp.rank
    model = exp.model.eval()
    with torch.no_grad():
        mse_loss = nn.MSELoss()
        for step, (batch, label) in enumerate(rand_loader):
            label = label.to(rank)
            batch = {k: v.to(rank) for k, v in batch.items()}
            for modality in batch.keys():
                output = model(batch[modality]).mean
                loss = mse_loss(output, label)
                if exp.logger:
                    exp.logger.add_scalars('val/loss',
                                           {'val_loss': loss.item()},
                                           step)
def run_epochs(rank, exp: Experiment):
    print(f"Running basic DDP example on rank {rank}.")
    exp.rank = rank
    # only one process (rank 0) creates a SummaryWriter
    if not exp.distributed or (rank % exp.world_size == 0):
        print(f'setting up logger for rank {rank}')
        exp.setup_logger()
    if exp.distributed:
        setup(rank, exp.world_size)
        sampler = DistributedSampler(RandomDataset(), num_replicas=exp.world_size, rank=rank)
    else:
        sampler = None
    rand_loader = DataLoader(dataset=RandomDataset(),
                             batch_size=8, shuffle=False, pin_memory=True,
                             sampler=sampler, num_workers=0)
    for epoch in range(5):
        if exp.distributed:
            sampler.set_epoch(epoch)
        train(exp, rand_loader)
        validate(exp, rand_loader)
    if exp.distributed:
        cleanup()
    if exp.logger:
        exp.logger.close()
    print('done!')
if __name__ == '__main__':
    log_dir = 'temp_dir'
    use_distributed_training = True
    ex = Experiment(use_distributed_training, log_dir)
    if ex.distributed:
        mp.spawn(run_epochs,
                 args=(ex,),
                 nprocs=ex.world_size,
                 join=True)
    else:
        run_epochs(torch.device('cuda'), ex)
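With this setup, only the rank-0 process ever calls setup_logger and creates a SummaryWriter; on every other rank exp.logger stays None, so the if exp.logger: guards in train and validate simply skip the add_scalars calls. All event files therefore end up in temp_dir and can be inspected as usual, e.g. with tensorboard --logdir temp_dir.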