Thank you so much! That helped me up to a point, but now I am running into more trouble: I am trying to run my code on an HPC cluster with Slurm. My whole code looks like this:
import os
import torch
import torch.nn as nn
import torch.distributed as dist


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(512, 256, (4, 4), stride=(2, 2), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, (5, 5), (1, 1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, (2, 2), (2, 2)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, (2, 2), stride=(2, 2)),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 2, (1, 1), stride=(1, 1)),
            nn.BatchNorm2d(2),
            nn.LeakyReLU(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(2, 32, (1, 1), stride=(1, 1), padding=0, padding_mode='circular'),
            nn.LayerNorm(32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Conv2d(32, 64, (5, 5), stride=(2, 2), padding=2, padding_mode='circular'),
            nn.LayerNorm(16),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, (5, 5), stride=(2, 2), padding=2, padding_mode='circular'),
            nn.LayerNorm(8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, (5, 5), stride=(2, 2), padding=2, padding_mode='circular'),
            nn.LayerNorm(4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, (4, 4), stride=(1, 1), padding=0, padding_mode='circular')
        )

    def forward(self, x):
        return self.main(x)
if __name__ == '__main__':
    # DDP setting
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    ngpus_per_node = torch.cuda.device_count()

    if world_size > 1:
        if 'SLURM_PROCID' in os.environ:  # for slurm scheduler
            # global rank of this task and the local GPU it maps to
            rank = int(os.environ['SLURM_PROCID'])
            gpu = rank % torch.cuda.device_count()
            print("gpu", gpu, "rank", rank, "ngpus_per_node", ngpus_per_node)

        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=world_size, rank=rank)

        # one process group per model: ranks 0-1 for the generator, ranks 2-3 for the discriminator
        group_g = torch.distributed.new_group(ranks=[0, 1])
        group_d = torch.distributed.new_group(ranks=[2, 3])

        if gpu is not None:
            torch.cuda.set_device(gpu)

        with torch.cuda.device(0):
            G = Generator().cuda()
            D = Discriminator().cuda()
            G = torch.nn.parallel.DistributedDataParallel(G, device_ids=[0, 1], process_group=group_g, broadcast_buffers=True)
            D = torch.nn.parallel.DistributedDataParallel(D, device_ids=[2, 3], process_group=group_d, broadcast_buffers=True)
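For reference, to see what each task actually gets from the scheduler before the init runs, I print the relevant environment variables at the top of the script, roughly like this (just a sanity-check sketch; these are the standard Slurm and torch.distributed env:// variables, nothing specific to my cluster):

import os

# print the standard Slurm / DDP environment variables this task sees
for key in ("SLURM_PROCID", "SLURM_NTASKS", "SLURM_LOCALID",
            "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT",
            "CUDA_VISIBLE_DEVICES"):
    print(key, "=", os.environ.get(key, "<not set>"))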