Hi,
to speed up my training I was looking into PyTorch's DistributedDataParallel, since the docs state that DataParallel has a lot of overhead that reduces speed. I tested a couple of hyperparameters and found some odd behavior, which left me wondering whether I overlooked something.
I am running on a 64-bit Linux cluster node with 64 cores, 350+ GB of RAM and 4 Nvidia Tesla V100 GPUs (16 GB each). I tested the stable 1.5.0 release and the nightly 1.7.0.dev20200720, because I wanted to try automatic mixed precision as an additional speed-up.
The model I was testing is a BERT model from the transformers library, with a single linear layer and a BCEWithLogitsLoss.
I tested three training modes (all single machine): 1. a single GPU, 2. multi-GPU with DataParallel, 3. multi-GPU with DistributedDataParallel.
For each mode I then varied pin_memory and num_workers for the DataLoader, plus mixed precision where available.
Code for reference:
import os
from datetime import datetime
from argparse import ArgumentParser
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from transformers import AdamW, BertConfig
from prediction_module import path_vocab, path_raw_uniprot
from prediction_module.protein_datasets import ProteinBertLabeledDataset
from prediction_module.helpers import get_logger
logger = get_logger(__file__)
def train_model_dp(dataset, batch_size=4, n_steps=1000, num_workers=0, parallel=True, mixed_pres=False, pin_memory=False):
    from prediction_module.protein_models import ProteinBertForMultiLabel
    if mixed_pres:
        # wrap the model's forward in autocast so it runs under mixed precision
        ProteinBertForMultiLabel.forward = torch.cuda.amp.autocast()(ProteinBertForMultiLabel.forward)
    torch.manual_seed(0)
    config = BertConfig(
        vocab_size=dataset.tokenizer.vocab_size,
        num_labels=dataset.num_labels,
        max_position_embeddings=dataset.input_size
    )
    model = ProteinBertForMultiLabel(config)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and parallel:
        # scale the batch size so every GPU still sees a mini-batch of 4
        batch_size = batch_size * torch.cuda.device_count()
        model = torch.nn.DataParallel(model)
    logger.debug(f"testing: {batch_size=} {num_workers=} {parallel=} {mixed_pres=} {pin_memory=}")
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             collate_fn=dataset.collate_fn,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=pin_memory)
    model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=1e-5)  # create optimizer
    if mixed_pres:
        scaler = torch.cuda.amp.GradScaler()
    start = datetime.now()
    for epoch in range(1):  # a single pass is enough for timing
        for i, inputs in enumerate(dataloader):
            # move every tensor in the batch to the GPU
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device, non_blocking=True)
            # zero the parameter gradients
            optimizer.zero_grad()
            if mixed_pres:
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
                loss = outputs[0]
                loss = loss.mean()
                # backward and optimize with gradient scaling
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # forward + backward + optimize
                outputs = model(**inputs)
                loss = outputs[0]
                loss = loss.mean()
                loss.backward()
                optimizer.step()
            if i >= n_steps:
                break
    logger.debug("Training complete in: %s. normalized by batch size: %s",
                 str(datetime.now() - start), str((datetime.now() - start) / batch_size))
def train_start(rank, world_size, batch_size=4, mixed_pres=False, pin_memory=True, num_workers=0, n_steps=1000, epochs=1):
    from prediction_module.protein_models import ProteinBertForMultiLabel
    if mixed_pres:
        ProteinBertForMultiLabel.forward = torch.cuda.amp.autocast()(ProteinBertForMultiLabel.forward)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.manual_seed(0)
    torch.cuda.set_device(rank)
    dataset = ProteinBertLabeledDataset(
        vocab=path_vocab,
        csv_path=os.path.join(path_raw_uniprot, "raw_data.csv"),
        h5_path=os.path.join(path_raw_uniprot, "metled_go_data.h5")
    )
    config = BertConfig(
        vocab_size=dataset.tokenizer.vocab_size,
        num_labels=dataset.num_labels,
        max_position_embeddings=dataset.input_size
    )
    model = ProteinBertForMultiLabel(config)
    model.cuda(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank,
                                                      find_unused_parameters=True)
    model.train()
    optimizer = AdamW(model.parameters(), lr=1e-5)  # create optimizer
    if mixed_pres:
        scaler = torch.cuda.amp.GradScaler()
    # Data loading code
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory,
                                               sampler=train_sampler,
                                               collate_fn=dataset.collate_fn)
    start = datetime.now()
    for epoch in range(epochs):
        for i, inputs in enumerate(train_loader):
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.cuda(rank, non_blocking=True)
            optimizer.zero_grad()
            if mixed_pres:
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
                loss = outputs[0]
                loss = loss.mean()
                # Backward and optimize
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(**inputs)
                loss = outputs[0]
                loss = loss.mean()
                loss.backward()
                optimizer.step()
            if i >= n_steps:
                break
    if rank == 0:
        logger.debug("Training complete in: %s", str(datetime.now() - start))
    dist.destroy_process_group()
def train_model_ddp(world_size=4, mixed_pres=False, batch_size=4, pin_memory=False, num_workers=0, n_steps=1000):
    logger.debug(f"testing: {batch_size=} {num_workers=} {mixed_pres=} {pin_memory=}")
    mp.spawn(train_start,
             args=(world_size, batch_size, mixed_pres, pin_memory, num_workers, n_steps),
             nprocs=world_size,
             join=True)
if __name__ == "__main__":
    try:
        from torch.cuda.amp import autocast
        mp_avail = True
    except ImportError:
        mp_avail = False
    parser = ArgumentParser()
    parser.add_argument("--test-dp", dest="test_dp", default=False, const=True, nargs="?")
    parser.add_argument("--test-ddp", dest="test_ddp", default=False, const=True, nargs="?")
    args = parser.parse_args()
    args_dict = vars(args)
    logger.debug("torch version: %s", torch.__version__)
    if args_dict["test_dp"]:
        dataset = ProteinBertLabeledDataset(
            vocab=path_vocab,
            csv_path=os.path.join(path_raw_uniprot, "raw_data.csv"),
            h5_path=os.path.join(path_raw_uniprot, "metled_go_data.h5")
        )
        logger.debug("testing single gpu")
        train_model_dp(dataset, parallel=False)
        train_model_dp(dataset, parallel=False)
        if mp_avail:
            train_model_dp(dataset, parallel=False, mixed_pres=True)
        train_model_dp(dataset, parallel=False, num_workers=8)
        train_model_dp(dataset, parallel=False, num_workers=16)
        train_model_dp(dataset, parallel=False, pin_memory=True)
        logger.debug("testing dp")
        train_model_dp(dataset)
        train_model_dp(dataset, num_workers=8)
        train_model_dp(dataset, num_workers=16)
        train_model_dp(dataset, pin_memory=True)
        if mp_avail:
            train_model_dp(dataset, mixed_pres=True)
    if args_dict["test_ddp"]:
        logger.debug("testing ddp")
        train_model_ddp()
        train_model_ddp(pin_memory=True)
        train_model_ddp(num_workers=8)
        train_model_ddp(num_workers=16)
        if mp_avail:
            train_model_ddp(mixed_pres=True)
The results:
testing single gpu
torch version: 1.5.0
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:48.407579. normalized by batch size: 0:00:42.101900
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:47.146963. normalized by batch size: 0:00:41.786745
testing: batch_size=4 num_workers=8 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:49.422436. normalized by batch size: 0:00:42.355613
testing: batch_size=4 num_workers=16 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:50.284026. normalized by batch size: 0:00:42.571010
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=True
Training complete in: 0:02:47.878925. normalized by batch size: 0:00:41.969736
testing dp
testing: batch_size=16 num_workers=0 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:32.129513. normalized by batch size: 0:00:20.758095
testing: batch_size=16 num_workers=8 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:28.702392. normalized by batch size: 0:00:20.543900
testing: batch_size=16 num_workers=16 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:29.794879. normalized by batch size: 0:00:20.612181
testing: batch_size=16 num_workers=0 parallel=True mixed_pres=False pin_memory=True
Training complete in: 0:05:24.955569. normalized by batch size: 0:00:20.309724
torch version: 1.7.0.dev20200720
testing single gpu
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:50.061025. normalized by batch size: 0:00:42.515261
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:48.032688. normalized by batch size: 0:00:42.008176
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=True pin_memory=False
Training complete in: 0:01:54.984463. normalized by batch size: 0:00:28.746120
testing: batch_size=4 num_workers=8 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:50.344483. normalized by batch size: 0:00:42.586124
testing: batch_size=4 num_workers=16 parallel=False mixed_pres=False pin_memory=False
Training complete in: 0:02:51.148356. normalized by batch size: 0:00:42.787092
testing: batch_size=4 num_workers=0 parallel=False mixed_pres=False pin_memory=True
Training complete in: 0:02:48.677086. normalized by batch size: 0:00:42.169276
testing dp
testing: batch_size=16 num_workers=0 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:30.977989. normalized by batch size: 0:00:20.686125
testing: batch_size=16 num_workers=8 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:26.893676. normalized by batch size: 0:00:20.430856
testing: batch_size=16 num_workers=16 parallel=True mixed_pres=False pin_memory=False
Training complete in: 0:05:28.139827. normalized by batch size: 0:00:20.508740
testing: batch_size=16 num_workers=0 parallel=True mixed_pres=False pin_memory=True
Training complete in: 0:05:22.767213. normalized by batch size: 0:00:20.172952
testing: batch_size=16 num_workers=0 parallel=True mixed_pres=True pin_memory=False
Training complete in: 0:04:26.452442. normalized by batch size: 0:00:16.653278
torch version: 1.5.0
testing ddp
testing: batch_size=4 num_workers=0 mixed_pres=False pin_memory=False
Training complete in: 0:04:59.752312
testing: batch_size=4 num_workers=0 mixed_pres=False pin_memory=True
Training complete in: 0:04:59.236787
testing: batch_size=4 num_workers=8 mixed_pres=False pin_memory=False
Training complete in: 0:12:16.935697
torch version: 1.7.0.dev20200720
testing ddp
testing: batch_size=4 num_workers=0 mixed_pres=False pin_memory=False
Training complete in: 0:05:02.979028
testing: batch_size=4 num_workers=0 mixed_pres=False pin_memory=True
Training complete in: 0:05:03.088308
testing: batch_size=4 num_workers=8 mixed_pres=False pin_memory=False
Training complete in: 0:11:05.255453
testing: batch_size=4 num_workers=0 mixed_pres=True pin_memory=False
Training complete in: 0:05:10.881854
My interpretation
Training on a single GPU takes about 2:50 minutes regardless of the dataloader settings; only mixed precision helps, cutting it to just under 2 minutes.
So perfect parallelization would mean the same wall-clock time with 4 GPUs, as long as each GPU gets its own mini-batch of size 4, correct?
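To make that concrete, here is my own back-of-the-envelope per-sample arithmetic from the timings above (assuming each run completes the full 1000 steps as in the code):

    # rough wall-clock time per sample, derived from the logged runs (1000 steps each)
    t_single = 168 / (1000 * 4)      # 2:48 on one GPU, batch 4          -> ~42 ms/sample
    t_dp     = 325 / (1000 * 16)     # 5:25 with DataParallel, batch 16  -> ~20 ms/sample
    t_ddp    = 300 / (1000 * 4 * 4)  # 5:00 with DDP, 4 procs x batch 4  -> ~19 ms/sample

By this measure both multi-GPU setups give roughly 2x the single-GPU throughput, not the 4x that perfect scaling would imply.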
DataParallel reacts to the hyperparameters in much the same way: training takes around 5:25 minutes in every configuration except mixed precision, which brings it down to about 4:26 minutes.
Now to DistributedDataParallel:
Increasing the number of workers slows training down dramatically (from about 5:00 with num_workers=0 to over 11 minutes with num_workers=8).
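One thing I only realized while writing this up: mp.spawn starts one training process per GPU, and each process builds its own DataLoader, so num_workers is per process and the worker count multiplies. A sketch of what I mean (the division below is my own guess at a mitigation, not something from the docs, and I don't know whether this fully explains the slowdown):

    # with world_size=4 and num_workers=8, each spawned process creates 8
    # loader workers, i.e. 4 * 8 = 32 extra processes competing on the node.
    # hypothetical mitigation: split the worker budget across processes
    workers_per_proc = max(1, num_workers // world_size)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               num_workers=workers_per_proc,
                                               pin_memory=pin_memory,
                                               collate_fn=dataset.collate_fn)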
Mixed precision has no effect on training speed, even though I could see on the GPUs that memory consumption dropped compared to full precision, to roughly the same level as with mixed precision under DataParallel.
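To rule out autocast silently being inactive inside the spawned processes, I would add a quick check in train_start (my own debugging sketch, not something I ran in the benchmarks above):

    # verify the monkey-patched forward really runs under autocast in this
    # DDP process; torch.is_autocast_enabled() reports the current state
    if mixed_pres:
        with torch.cuda.amp.autocast():
            assert torch.is_autocast_enabled()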
This is my first time using PyTorch, so if I overlooked anything please let me know. Otherwise I would be interested in what causes these effects.