Thank you for your endless help, @ptrblck .
I’m following the way DDP is used in this repository.
And just in case, let me mention that:
- my model contains a BiLSTM
- I’m also using a quantizer (but it is omitted from the abstracted code below); a toy stand-in sketch follows this list.
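To make those two points concrete, here is a toy stand-in (ToyModel, ToyQuantizer, and all sizes are made up purely for illustration; this is not my actual architecture):

import torch
from torch import nn


class ToyQuantizer(nn.Module):
    # Parameter-free stand-in; the real quantizer is omitted here as well.
    def forward(self, x):
        # Straight-through rounding so gradients still flow in this toy example.
        return x + (torch.round(x) - x).detach()


class ToyModel(nn.Module):
    def __init__(self, input_size=80, hidden_size=256):
        super().__init__()
        self.blstm = nn.LSTM(input_size, hidden_size,
                             batch_first=True, bidirectional=True)
        self.quantizer = ToyQuantizer()
        self.head = nn.Linear(2 * hidden_size, input_size)

    def forward(self, x):  # x: (batch, time, features)
        out, _ = self.blstm(x)  # BiLSTM doubles the feature dimension
        out = self.quantizer(out)
        return self.head(out)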
Abstract of my work’s __main__.py
The order of instance initialization is identical to that in my actual source code.
import argparse

import torch
from torch import distributed
from torch.nn.parallel import DistributedDataParallel


def train(args):
    device = torch.device(f'cuda:{args.rank}')
    torch.cuda.set_device(device=device)
    distributed.init_process_group(
        backend='nccl',
        init_method=f'tcp://{args.master_url}',
        world_size=args.world_size,
        rank=args.rank,
    )

    # Instantiate my torch.utils.data.Dataset object.
    train_dataset = MyDataset()

    # Instantiate my model.
    model = MyModule()
    model.to(device)

    # Augmentation modules, which have no parameters.
    augmentations = [
        Augmentation1(),
        Augmentation2(),
        Augmentation3(),
    ]
    augmentations = torch.nn.Sequential(*augmentations)
    augmentations.to(device)

    # And instantiate etc.
    optimizer = ...
    criterion = ...

    # I've checked the memory usage here, and it says 1140.xx MiB.

    # Wrap the model with DDP.
    model = DistributedDataParallel(
        module=model,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device(),
    )

    # I've checked the memory usage here again, and it says 2281.xx MiB.

    # Instantiate my Trainer class, whose abstract is below.
    trainer = Trainer(
        model=model,
        dataset=train_dataset,
        augmentations=augmentations,
        criterion=criterion,
        optimizer=optimizer,
        batch_size=batch_size,
        num_workers=args.num_workers,
        device=device,
        world_size=args.world_size,
    )

    for epoch in range(args.epochs):
        for metrics in trainer.train(epoch):  # train 1 epoch
            print(metrics)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--num_workers", type=int)
    parser.add_argument("--rank", type=int, default=-1)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--master_url", type=str)
    args = parser.parse_args()
    train(args)
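For reference, the two figures in the comments above were taken right before and right after the DistributedDataParallel wrapping. A minimal sketch of one way to read such numbers (the helper name report_memory is made up for illustration; I am not claiming this is exactly how the figures above were obtained):

import torch


def report_memory(tag, device):
    # Currently allocated vs. reserved (cached) CUDA memory on `device`, in MiB.
    allocated = torch.cuda.memory_allocated(device) / 2 ** 20
    reserved = torch.cuda.memory_reserved(device) / 2 ** 20
    print(f"[{tag}] allocated: {allocated:.2f} MiB, reserved: {reserved:.2f} MiB")


# e.g.
# report_memory("before DDP", device)
# report_memory("after DDP", device)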
Abstract of my Trainer class
class Trainer:
    def __init__(self, [GIVEN ARGS]):
        self.model = model
        self.dataset = dataset
        self.augmentations = augmentations
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_workers = num_workers // world_size
        self.device = device
        self.world_size = world_size
        self.batch_size = batch_size // world_size
        self.sampler = DistributedSampler(
            dataset=dataset,
            shuffle=dataset.is_trainset(),
        )
        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size if dataset.is_trainset() else 1,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=dataset.is_trainset() is False,
            sampler=self.sampler,
        )
    def train(self, epoch):
        self.model.train()
        self.sampler.set_epoch(epoch)
        for x, y in self.dataloader:
            self.optimizer.zero_grad()
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
            x = self.augmentations(x)
            y_hat = self.model(x)
            cost = self.criterion(input=y_hat, target=y)
            cost.backward()
            self.optimizer.step()
            del x, y, y_hat
            yield cost.item()
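Since both batch_size and num_workers are divided by world_size, each rank reads its own shard of the dataset with the per-rank batch size. A standalone toy example of that split, with made-up sizes and data (not my real dataset), just to show what DistributedSampler does:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 samples; pretend the global batch_size is 4 and world_size is 2.
dataset = TensorDataset(torch.arange(8).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4 // 2, sampler=sampler)

print(len(sampler))  # 4 -> rank 0 sees half of the samples per epoch
print(len(loader))   # 2 -> two steps with a per-rank batch size of 2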
run.py
The script for spawning the processes is almost identical to the referenced repository’s.
import subprocess as sp
import sys

import torch


def main():
    args = sys.argv[1:]  # arguments for __main__.py
    gpus = torch.cuda.device_count()  # supposed to be 1 in my case.
    free_port = get_free_port()  # helper returning an unused local port (definition omitted).
    master_url = f'127.0.0.1:{free_port}'
    args += ["--world_size", str(gpus), "--master_url", master_url]

    tasks = []
    # Spawn one training process per rank (rank 0 .. gpus - 1).
    for gpu in range(gpus):
        tasks.append(sp.Popen(["python3", "-m", "my_model"] + args + ["--rank", str(gpu)]))
        tasks[-1].rank = gpu
    failed = False
    while tasks:
        for task in tasks:
            try:
                exitcode = task.wait(0.1)
            except sp.TimeoutExpired:
                continue
            else:
                tasks.remove(task)
                if exitcode:
                    print(f"Task {task.rank} died with exit code {exitcode}",
                          file=sys.stderr)
                    failed = True
        if failed:
            break
    if failed:
        for task in tasks:
            task.terminate()
        sys.exit(1)


if __name__ == "__main__":
    main()
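One more note: get_free_port() itself is not shown above. A minimal sketch of what such a helper could look like (an assumption for illustration, not necessarily the implementation actually used):

import socket


def get_free_port():
    # Bind to port 0 so the OS assigns an unused port, read it back, then release the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]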
Thanks again.