Can someone please verify whether the DDP implementation below is correct? I have a feeling that the data is somehow being duplicated, so that the same training is repeated on every GPU instead of being split between them. Is there a way to check how much data each GPU actually receives?
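My suspicion is that, since I never attach a DistributedSampler to my dataloaders, every rank iterates over the full dataset. Below is a rough sketch of what I think the fix would look like (the name train_dataset is just a stand-in for whatever Dataset backs my train_dataloader; the sampler would be built inside the per-rank worker and set_epoch would be called at the top of every epoch):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# give each rank a disjoint shard of the dataset
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
                                   rank=rank, shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                              sampler=train_sampler)

# at the start of every epoch, so shuffling differs between epochs
train_sampler.set_epoch(epoch)

Is that the right approach, or does DDP shard the data on its own? My current code is below.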
def setup(rank, world_size, port):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)  # environment variables must be strings
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
def run_parallel(rank, args, world_size):
    # mp.spawn passes the process rank as the first argument,
    # followed by the tuple given in args=(args, world_size)
    import time
    start = time.process_time()
    logger = create_logger(args)
    globals()['logger'] = logger
    logging.info(f"Running training on rank {rank}.")
    logging.info(args)
    logging.info(f"DEVICE COUNT {torch.cuda.device_count()}")
    logging.info(f"GPU: {args.gpu}")
    setup(rank, world_size, args.port)

    writer_logdir = os.path.join(args.logfile, 'tb')
    tb = SummaryWriter(log_dir=writer_logdir)

    ## BUILD MODEL ##
    model = RegNet(args).to(rank)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank)

    train = TrainModel(ddp_model, train_dataloader, test_dataloader, args,
                       NUM_CLASS, tb=tb, save_version=args.save_version)
    train.run()

    if tb is not None:
        tb.close()
    cleanup()
class TrainModel():
    ...

    def run(self):
        if self.args.sgd:
            optimizer = SGD(self.model.parameters(), lr=self.args.lr)
        else:
            optimizer = Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = None  # placeholder: the real scheduler setup is omitted from this excerpt
        for epoch in range(self.args.epoch):
            self.train_epoch(optimizer, scheduler, epoch)
            self.cur_epoch = epoch

    def train_epoch(self, optimizer, scheduler, epoch):
        self.model.train()
        idx = 0
        for n_iter, samples in enumerate(self.train_dataloader):
            fixed, fixed_label, moving, moving_label, fixed_nopad, seg_fname = \
                self.data_extract(samples, self.model.device)
            if not self.save_seg:
                seg_fname = None
            self.global_idx += 1
            self.cur_idx = idx
            logging.info(f'iteration={idx}/{len(self.train_dataloader)}')
            idx += 1
            loss, trdice = self.trainIter(fixed, moving, fixed_label, moving_label,
                                          fixed_nopad=fixed_nopad, seg_f=seg_fname)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
if __name__ == "__main__":
    args = get_args()
    # 'file_system' sharing strategy avoids "too many open files" errors from dataloader workers
    torch.multiprocessing.set_sharing_strategy('file_system')
    # restrict visible GPUs before any CUDA call so the setting takes effect
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.empty_cache()
    gpu = [int(i) for i in range(torch.cuda.device_count())]
    world_size = len(gpu)
    mp.spawn(run_parallel,
             args=(args, world_size),
             nprocs=world_size,
             join=True)
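To answer my own sub-question about measuring the split, would something like the per-rank counter below be a reliable check? This is just a sketch of what I had in mind (it would be called inside each spawned process, after setup()), not something taken from the DDP docs:

import logging
import torch.distributed as dist

def log_rank_data_share(dataloader):
    # count how many batches and samples this process actually sees in one pass
    rank = dist.get_rank()
    n_samples = 0
    for samples in dataloader:
        # assumes the first element of each batch is a tensor whose dim 0 is the batch dimension
        n_samples += samples[0].shape[0]
    logging.info(f"rank {rank}: {len(dataloader)} batches, {n_samples} samples")

My expectation is that with correct sharding each rank would report roughly len(dataset) / world_size samples, whereas with duplication each rank would report the full dataset size. Is that reasoning correct?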