The code below works when run from a terminal, but fails when run in a Jupyter Notebook:
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# Expose only physical GPUs 6 and 7 to CUDA for this run.
os.environ.update({"CUDA_VISIBLE_DEVICES": "6,7"})
def setup(rank, world_size):
    """Join the NCCL process group via the rendezvous at localhost:12355.

    Args:
        rank: Index of the calling process; forwarded to
            ``dist.init_process_group``.
        world_size: Total number of participating processes; forwarded to
            ``dist.init_process_group``.
    """
    # Rendezvous endpoint read by the process-group initialisation below.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')

    # Initialise the process group.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the process group of the calling process."""
    dist.destroy_process_group()
class ToyModel(nn.Module):
    """Two-layer perceptron: Linear(10 -> 10), ReLU, Linear(10 -> 5)."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        """Return ``net2(relu(net1(x)))``.

        ``x`` must have a trailing dimension of 10 (the input size of
        ``net1``); the result has a trailing dimension of 5.
        """
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)
def demo_basic(rank, world_size):
    """Run one forward/backward/optimizer step of ``ToyModel`` under DDP.

    Args:
        rank: Index of this process; also used as the GPU index the model
            and tensors are placed on.
        world_size: Total number of processes in the group.

    The process group created by ``setup`` is always destroyed, even if the
    training step raises.
    """
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # setup() stays outside the try: if it fails there is no group to destroy.
    try:
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # Place the inputs on the same device as the model and labels
        # explicitly, rather than relying on the DDP wrapper to move them.
        inputs = torch.ones(200, 10).to(rank)
        outputs = ddp_model(inputs)
        labels = torch.randn(200, 5).to(rank)
        loss = loss_fn(outputs, labels)
        print("Loss is ", loss.item())
        loss.backward()
        optimizer.step()
    finally:
        cleanup()
if __name__ == '__main__':
    world_size = 2
    print("We have available ", torch.cuda.device_count(), "GPUs! but using ", world_size, " GPUs")
    #########################################################
    # `args` must be a tuple: mp.spawn calls fn(i, *args), so the bare int
    # produced by `(world_size)` cannot be unpacked. Note the trailing comma.
    # NOTE(review): the traceback shows start_method='spawn'; worker processes
    # presumably need to import demo_basic from a module, which would explain
    # why a function defined in a notebook cell fails - confirm.
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)
    #########################################################
Terminal output:
We have available 2 GPUs! but using 2 GPUs
Running basic DDP example on rank 1.
Running basic DDP example on rank 0.
Loss is 1.0888941287994385
Loss is 1.0354920625686646
Jupyter Notebook Output:
We have available 2 GPUs! but using 2 GPUs
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-2-52a9f6d32955> in <module>
68
69 #########################################################
---> 70 mp.spawn(demo_basic, args=(world_size), nprocs=world_size, join=True)
71 #########################################################
~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
198 ' torch.multiprocessing.start_process(...)' % start_method)
199 warnings.warn(msg)
--> 200 return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
156
157 # Loop on join until it returns True or raises an exception.
--> 158 while not context.join():
159 pass
160
~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
111 raise Exception(
112 "process %d terminated with exit code %d" %
--> 113 (error_index, exitcode)
114 )
115
Exception: process 1 terminated with exit code 1
Can anyone explain why this happens?