Hello,
I would like to know whether a large gap in accuracy is expected when using DDP.
When I train for 3 epochs, I get the following validation accuracies:
- Without DDP: (1) 64.77% → (2) 72.21% → (3) 78.21%
- With DDP: (1) 49.34% → (2) 59.54% → (3) 65.89%
I saw in other posts that I should adapt the batch size and learning rate when using DDP (batch size x8 if I use 8 GPUs, and lr multiplied by sqrt(8)), but when I tried that it only got worse (around 30% accuracy after 3 epochs).
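For reference, the scaling from those posts amounts to something like this (a minimal sketch using the base settings from the script below; the variable names are just for illustration):

import math

num_gpus = 8
base_batch_size = 64    # per-process batch size used in the script below
base_lr = 0.001

scaled_batch_size = base_batch_size * num_gpus   # 64 * 8 = 512
scaled_lr = base_lr * math.sqrt(num_gpus)        # 0.001 * sqrt(8) ≈ 0.00283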
Am I doing something wrong?
Here is a simplified version of my code that reproduces this behaviour:
# python3 DDP.py
# python3 -u -m torch.distributed.run --nproc_per_node=8 --nnodes=1 --master_port=2223 DDP.py
import os, random, math, logging
from typing import Dict, Any
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
from torch.optim.optimizer import Optimizer
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms, models
import torchvision.transforms.functional as tF
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.datasets import ImageFolder, CIFAR10
cudnn.enabled = False
logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
# parameters
batch_size = 64
lr = 0.001
wd = 0
path_data = "./data/cifar10"
epochs = 3
# Reproducibility
seed = 0
torch.manual_seed(seed) # for cpu
torch.cuda.manual_seed(seed) # for gpu
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
# distributed
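# torchrun / torch.distributed.run exports WORLD_SIZE, RANK and MASTER_ADDR/MASTER_PORT
# for every spawned process, so this flag is only True when using the second launch command above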
distributed = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1
if distributed:
torch.distributed.init_process_group(backend='nccl', init_method='env://')
world_size = torch.distributed.get_world_size()
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in range(world_size))
rank = torch.distributed.get_rank()
device = torch.device(f'cuda:{rank}')
torch.cuda.set_device(rank)
logger.info(f'Working in distributed mode. Process {rank}/{world_size-1}.')
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
rank = 0
device = torch.device("cuda:0")
world_size = 1
    logger.info('Working on one GPU.')
if rank != 0:
logger.disabled = True
# ==================== LOAD DATA
def fast_collate(batch):
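    # stacks uint8 images (numpy arrays or tensors) into one uint8 batch tensor;
    # the float conversion and normalization happen later on the GPU in PrefetchLoader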
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
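# loads the next batch on a side CUDA stream and normalizes it on the GPU,
# overlapping the host-to-device copy with compute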
class PrefetchLoader:
def __init__(self, loader, mean, std, channels=3):
normalization_shape = (1, channels, 1, 1)
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.type(torch.int64).cuda(non_blocking=True)
next_input = next_input.float().sub_(self.mean).div_(self.std)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
# cifar10
path_data_train = os.path.join(path_data, 'train')
path_data_val = os.path.join(path_data, 'val')
data_train = CIFAR10(path_data, train=True)
data_val = CIFAR10(path_data, train=False)
data_train.transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(p=0.5), ToNumpy()])
data_val.transform = transforms.Compose([ToNumpy()])
mean=(0.4914, 0.4822, 0.4465)
std=(0.2023, 0.1994, 0.2010)
sampler_train = None
sampler_val = None
if distributed:
sampler_train = torch.utils.data.distributed.DistributedSampler(data_train, num_replicas=world_size, rank=rank)
sampler_val = torch.utils.data.distributed.DistributedSampler(data_val, num_replicas=world_size, rank=rank)
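    # each rank gets a distinct 1/world_size shard of the dataset per epoch;
    # set_epoch() is called in train_one_epoch so the ordering changes every epoch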
loader_train = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=False, num_workers=16, sampler=sampler_train, collate_fn=fast_collate)
loader_val = torch.utils.data.DataLoader(data_val, batch_size=batch_size, shuffle=False, num_workers=16, sampler=sampler_val, collate_fn=fast_collate)
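# batch_size is the per-process batch size, so with DDP the effective batch per
# optimizer step is batch_size * world_size (64 * 8 = 512 with the launch command above)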
loader_train = PrefetchLoader(loader_train, mean=mean, std=std)
loader_val = PrefetchLoader(loader_val, mean=mean, std=std)
# ==================== LOAD MODEL
class VGG(nn.Module):
def __init__(self, features):
super(VGG, self).__init__()
self.features = features
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Linear(512, 10),
)
# Initialize weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.bias.data.zero_()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def make_layers(cfg=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
model = VGG(make_layers()).to(device)
# ==================== SCHEDULER
class Scheduler:
""" Parameter Scheduler Base Class
A scheduler base class that can be used to schedule any optimizer parameter groups.
Unlike the builtin PyTorch schedulers, this is intended to be consistently called
* At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
* At the END of each optimizer update, after incrementing the update count, to calculate next update's value
The schedulers built on this should try to remain as stateless as possible (for simplicity).
This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
and -1 values for special behaviour. All epoch and update counts must be tracked in the training
code and explicitly passed in to the schedulers on the corresponding step or step_update call.
Based on ideas from:
* https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
param_group_field: str,
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize: bool = True) -> None:
self.optimizer = optimizer
self.param_group_field = param_group_field
self._initial_param_group_field = f"initial_{param_group_field}"
if initialize:
for i, group in enumerate(self.optimizer.param_groups):
if param_group_field not in group:
raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
group.setdefault(self._initial_param_group_field, group[param_group_field])
else:
for i, group in enumerate(self.optimizer.param_groups):
if self._initial_param_group_field not in group:
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
self.noise_seed = noise_seed if noise_seed is not None else 42
self.update_groups(self.base_values)
def state_dict(self) -> Dict[str, Any]:
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
def get_epoch_values(self, epoch: int):
return None
def get_update_values(self, num_updates: int):
return None
def step(self, epoch: int, metric=None) -> None:
values = self.get_epoch_values(epoch)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int):
values = self.get_update_values(num_updates)
if values is not None:
values = self._add_noise(values, num_updates)
self.update_groups(values)
def update_groups(self, values):
if not isinstance(values, (list, tuple)):
values = [values] * len(self.optimizer.param_groups)
for param_group, value in zip(self.optimizer.param_groups, values):
if 'lr_scale' in param_group:
param_group[self.param_group_field] = value * param_group['lr_scale']
else:
param_group[self.param_group_field] = value
def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
return apply_noise
def _calculate_noise(self, t) -> float:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
return noise
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise
class CosineLRScheduler(Scheduler):
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min: float = 0.,
cycle_mul: float = 1.,
cycle_decay: float = 1.,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
"rate since t_initial = t_mul = eta_mul = 1.")
self.t_initial = t_initial
self.lr_min = lr_min
self.cycle_mul = cycle_mul
self.cycle_decay = cycle_decay
self.cycle_limit = cycle_limit
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
self.k_decay = k_decay
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
if self.cycle_mul != 1:
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
t_i = self.cycle_mul ** i * self.t_initial
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
else:
i = t // self.t_initial
t_i = self.t_initial
t_curr = t - (self.t_initial * i)
gamma = self.cycle_decay ** i
lr_max_values = [v * gamma for v in self.base_values]
k = self.k_decay
if i < self.cycle_limit:
lrs = [
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
for lr_max in lr_max_values
]
else:
lrs = [self.lr_min for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
# ==================== OPTIMIZER
class Ranger(Optimizer):
def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.99), eps=1e-6, weight_decay=0):
#parameter checks
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
if not lr > 0:
raise ValueError(f'Invalid Learning Rate: {lr}')
if not eps > 0:
raise ValueError(f'Invalid eps: {eps}')
#prep defaults and init torch.optim base
defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
super().__init__(params,defaults)
#adjustable threshold
self.N_sma_threshhold = N_sma_threshhold
#look ahead params
self.alpha = alpha
self.k = k
#radam buffer for state
self.radam_buffer = [[None,None,None] for ind in range(10)]
def __setstate__(self, state):
print("set state called")
super(Ranger, self).__setstate__(state)
def step(self, closure=None):
loss = None
#Evaluate averages and grad, update param tensors
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Ranger optimizer does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p] #get state dict for this param
if len(state) == 0: #if first time to run...init dictionary with our desired entries
#if self.first_run_check==0:
#self.first_run_check=1
#print("Initializing slow buffer...should not see this at load from saved model!")
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
#look ahead weight storage now in state dict
state['slow_buffer'] = torch.empty_like(p.data)
state['slow_buffer'].copy_(p.data)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
#begin computations
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
#compute variance mov avg
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
#compute mean moving avg
exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1)
state['step'] += 1
buffered = self.radam_buffer[int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
if N_sma > self.N_sma_threshhold:
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
step_size = 1.0 / (1 - beta1 ** state['step'])
buffered[2] = step_size
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
if N_sma > self.N_sma_threshhold:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr'])
else:
p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr'])
p.data.copy_(p_data_fp32)
#integrated look ahead...
#we do it at the param level instead of group level
if state['step'] % group['k'] == 0:
slow_p = state['slow_buffer'] #get access to slow param tensor
slow_p.add_(p.data - slow_p, alpha = self.alpha) #(fast weights - slow weights) * alpha
p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
return loss
# ==================== TRAIN / VALIDATE
def param_groups_weight_decay(
model: nn.Module,
weight_decay=1e-5,
no_weight_decay_list=()
):
no_weight_decay_list = set(no_weight_decay_list)
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
parameters = param_groups_weight_decay(model, wd)
opti = Ranger(parameters, lr=lr)
sched = CosineLRScheduler(opti, t_initial=epochs)
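# t_in_epochs=True by default, so the LR only changes via sched.step() at the end of
# each epoch; sched.step_update() is a no-op with these settings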
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(device)
scaler = torch.cuda.amp.GradScaler() # native amp
def reduce_tensor(tensor, n):
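    # average a value across all processes: all_reduce sums it, then divide by n (= world_size)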
rt = tensor.clone()
torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
rt /= n
return rt
class AverageMeter(object):
def __init__(self):
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
batch_size = target.size(0)
_, pred = output.topk(1, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
return correct[:1].reshape(-1).float().sum(0) * 100. / batch_size
def train_one_epoch(epoch):
losses = AverageMeter()
model.train()
lrl = [param_group['lr'] for param_group in opti.param_groups]
cur_lr = sum(lrl) / len(lrl)
if distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
logger.info('-'*15)
logger.info(f'Epoch[{epoch}]')
logger.info(f'learning rate: {cur_lr:.7f}')
num_updates = epoch * len(loader_train)
for i, (images, target) in enumerate(tqdm(loader_train, leave=False)):
images = images.to(device)
target = target.to(device)
opti.zero_grad()
with autocast():
logits = model(images)
loss = criterion(logits, target)
# record loss
if distributed:
reduced_loss = reduce_tensor(loss.data, world_size)
losses.update(reduced_loss.item(), images.size(0))
else:
losses.update(loss.item(), images.size(0)) # accumulated loss
# logger.disabled = False
scaler.scale(loss).backward()
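        # scaler.step() unscales the gradients and skips the optimizer update if they contain inf/NaN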
scaler.step(opti)
# if distributed and i == len(loader_train)-1:
# logger.disabled = False
# logger.info(model.module.classifier[1].weight.grad)
# if rank != 0:
# logger.disabled = True
scaler.update()
num_updates += 1
sched.step_update(num_updates=num_updates)
if distributed:
torch.distributed.barrier()
logger.info(f'training loss: {losses.avg:.7f}')
sched.step(epoch+1)
def validate():
top1 = AverageMeter()
model.eval()
with torch.no_grad():
for i, (images, target) in enumerate(loader_val):
images = images.to(device)
target = target.to(device)
with autocast():
logits = model(images)
# measure accuracy
pred1 = accuracy(logits, target)
# if distributed and i == len(loader_val)-1:
# logger.disabled = False
# logger.info(pred1)
# if rank != 0:
# logger.disabled = True
if distributed:
pred1 = reduce_tensor(pred1, world_size)
top1.update(pred1.item(), images.size(0))
logger.info(f'validation acc: {top1.avg:.3f}')
return top1.avg
# setup distributed training
if distributed:
logger.info("Using DistributedDataParallel.")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, broadcast_buffers=True, device_ids=[rank])
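    # SyncBatchNorm shares batch-norm statistics across processes, and DDP averages
    # the gradients over all ranks during backward()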
best_acc = 0
epoch_best_acc = 0
epoch = 0
while epoch < epochs:
train_one_epoch(epoch)
val_acc = validate()
epoch += 1