i want to use DDP to train model ,use num 6th,7th gpu.
this code core is :
import datetime
import torch.utils.data.dataloader as dataloader
import sys
import pdb
from termcolor import cprint
import torch
from matplotlib import cm
from tqdm import tqdm
import time
import shutil
import nibabel as nib
import pdb
import argparse
import os
from torch.utils.data.distributed import DistributedSampler
if __name__ == '__main__':
parser = argparse.ArgumentParser('setup record')
# default method l
parser.add_argument("--DDP", default=True)
# optimizer and scheuler
parser.add_argument("--lr", default=5e-5)
parser.add_argument("--opt", default='adam',
choices=['adam', 'sgd'])
parser.add_argument("--num_gpus", default=[6, 7])
parser.add_argument("--bs", default=4)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.num_gpus))
if args.DDP:
print('init ddp')
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.num_gpus))
torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
if args.seed is not None:
numpy.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.benchmark = True
net = build_model()
if len(args.num_gpus) > 1 and not args.DDP:
# pdb.set_trace()
# net = BalancedDataParallel(args.maingpu_bs, net, dim=0).cuda()
net = torch.nn.DataParallel(net).cuda()
print('net to multi-gpu')
if len(args.num_gpus) > 1 and args.DDP:
print('using DDP model')
net = torch.nn.parallel.DistributedDataParallel(net,
device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
dataset = build_datset()
if args.DDP:
print('using ddp dataloader')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.bs, shuffle=True,
num_workers=args.works, pin_memory=True,
sampler=DistributedSampler(dataset))
else:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.bs, shuffle=True,
num_workers=args.works, pin_memory=True)
""""""
"""
Training
"""
print('setting dataloader')
what should i do???
thank you