I have some code where I need to spawn new process groups several times within a loop. On each iteration, I want to create a new process group, do some training, and then destroy the group.
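Stripped of the training details, the pattern I am after looks roughly like this (a minimal sketch with simplified names, not my actual script):

import os
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

def worker(rank, world_size):
    # each spawned worker sets up its own process group ...
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # ... does the per-iteration training work here ...
    destroy_process_group()  # ... and tears the group down again

if __name__ == "__main__":
    world_size = 8
    mp.set_start_method("spawn", force=True)
    for i in range(2):
        # spawn a fresh set of workers on every pass of the outer loop
        mp.spawn(worker, args=(world_size,), nprocs=world_size)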
My code snippet is below:
# Using ../sensitivity/36test.py
# Functions whose Fourier degree is concentrated on higher weights are harder to learn for LSTMs with SGD
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import random
import argparse
from transformer import Transformer
import os
import itertools
import time
import datetime
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
device = torch.device("cuda")
def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=60))
class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        N: int,  # input bit-width; validate() below relies on self.N
    ) -> None:
        self.gpu_id = gpu_id
        self.N = N
        model.to(self.gpu_id)  # move the model to its device before wrapping it in DDP
        self.model = DDP(model, device_ids=[self.gpu_id])
        self.train_data = train_data
        self.optimizer = optimizer

    def makeBitTensor(self, x, N):
        y = format(x, "b")
        y = ("0" * (N - len(y))) + y
        return [int(z) for z in list(y)]

    def func_batch(self, x):
        # real code has some fancier function here
        return torch.tensor(np.array([2] * len(x))).to(self.gpu_id)
    def _run_batch(self, inputs, targets):
        self.optimizer.zero_grad()
        inputs = inputs.to(self.gpu_id)  # Tensor.to() is not in-place, so keep the returned tensor
        result = self.model(inputs)
        #loss = -(result*targets).mean()
        loss = (result - targets).pow(2).mean()
        loss.backward()
        self.optimizer.step()
        return loss.detach().cpu()
    def _run_epoch(self, epoch):
        #b_sz = len(next(iter(self.train_data))[0])
        b_sz = len(next(iter(self.train_data)))
        #print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        epoch_loss = 0
        total_records = 0
        start_time = time.time()
        for idx, inputs in enumerate(self.train_data):
            #inputs.to(self.gpu_id)
            targets = self.func_batch(inputs).to(self.gpu_id)
            batch_loss = self._run_batch(inputs, targets)
            epoch_loss += batch_loss * float(len(inputs))
            total_records += len(inputs)
            iteration = epoch * len(self.train_data) + idx + 1
        epoch_loss /= float(total_records)
        end_time = time.time()
        elapsed_time = end_time - start_time
        #print(f"Epoch time: {elapsed_time:.3f} seconds.")
        time_per_record_ms = float(elapsed_time * 1000) / float(total_records)  # 1000 ms per second
        #print(f"Epoch time: {elapsed_time:.3f} seconds. time per record (ms): {time_per_record_ms: .3f}")
        return epoch_loss
    def save_checkpoint(self, epoch):
        # not called in this snippet; self.dir_name is presumably set elsewhere in the full script
        ckp = self.model.module.state_dict()
        torch.save(ckp, os.path.join(self.dir_name, f"model_{epoch}.pt"))
        print(f"Epoch {epoch} | Training checkpoint saved at model_{epoch}.pt")
    def train(self, epochs: int):
        self.model.train()
        for epoch in range(epochs):
            epoch_loss = self._run_epoch(epoch)
            if epoch_loss < 0.01:
                return
            #print("remainder: " + str(epoch % self.save_every))
            if (epoch % 20) == 0 and self.gpu_id == 0:
                val_loss = self.validate(1000)
                self.model.train()  # validate() puts the model in eval mode, so switch back
                print(f" Epoch: {epoch}, EpochLoss: {epoch_loss:.3f}, ValidationLoss: {val_loss:.3f}")
        return
    def validate(self, num_samples):
        self.model.eval()
        with torch.no_grad():  # no gradients needed for validation
            inputs = torch.tensor([random.randint(0, 2**self.N - 1) for _ in range(num_samples)]).to(self.gpu_id)
            targets = self.func_batch(inputs).to(self.gpu_id)
            result = self.model(inputs).to(self.gpu_id)
            loss = (result - targets).pow(2).mean()
        return loss.detach().cpu()
def load_train_objs(wd, dropout, lr, num_samples, N, dim, h, l, f, rank):
    train_set = torch.tensor([random.randint(0, 2**N - 1) for _ in range(int(num_samples))]).to(rank)
    model = Transformer(dropout, N, dim, h, l, f, 1e-5, rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(lr), weight_decay=wd)
    return train_set, model, optimizer
def parse_args():
    parser = argparse.ArgumentParser(description='linear spectrum non boolean test.')
    parser.add_argument('--world_size', type=int, default=8)
    parser.add_argument('--N', type=int, default=20)  # assumed: args.N is used below but this flag was missing from the pasted snippet
    parser.add_argument('--dim', type=int, default=20)
    parser.add_argument('--f', type=int, default=64)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--h', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--num_samples', type=int, default=100000)
    parser.add_argument('--lr', type=str, default="6e-6")
    parser.add_argument('--wd', type=float, default=.1)
    parser.add_argument('--dropout', type=float, default=.2)
    parser.add_argument('--repeat', type=int, default=100)
    return parser.parse_args()
def main(rank, args, world_size):
    ddp_setup(rank, world_size)
    # argument order matches the (wd, dropout, lr, ...) signature of load_train_objs above
    train_set, model, optimizer = load_train_objs(args.wd, args.dropout, args.lr, args.num_samples, args.N, args.dim, args.h, args.l, args.f, rank)
    model.to(rank)
    train_loader = DataLoader(
        train_set,
        shuffle=False,
        batch_size=args.bs,
        sampler=DistributedSampler(train_set)
    )
    trainer = Trainer(
        model,
        train_loader,
        optimizer,
        gpu_id=rank,
        N=args.N,
    )
    trainer.train(args.epochs)
    print("finished training, cleaning up process group...")
    destroy_process_group()
    print("finished cleaning up process group")
    return
if __name__ == "__main__":
    arguments = parse_args()
    print(arguments)
    losses = {}
    func_per_deg = arguments.repeat
    main_dir = f"N{arguments.N}_HidDim{arguments.dim}_L{arguments.l}_H{arguments.h}_FFDim{arguments.f}_4k"
    os.makedirs(main_dir, exist_ok=True)
    # with open("logs_width.txt", "a") as f:
    #     f.write("------------------------------------------\n")
    for i in range(2):
        start_time = time.time()
        world_size = torch.cuda.device_count()  # should match the number of processes spawned below
        mp.set_start_method('spawn', force=True)
        torch.set_num_threads(1)
        mp.spawn(main, args=(arguments, world_size), nprocs=arguments.world_size).join()
        print("returned from mp.spawn")
        end_time = time.time()
        elapsed_time = round((end_time - start_time) / 60, 3)
        print("elapsed time for whole training process: " + str(elapsed_time) + " minutes")
When I run this code, it trains successfully using all 8 GPUs on the cluster I'm working on. During the first iteration of the outer loop over i, every worker completes the whole body of the inner main function and prints "finished cleaning up process group". However, control never actually returns from mp.spawn; it just hangs until it times out.
If you look at my code, I have already applied all of the easily findable workarounds for this problem: calling join() on the result of spawn (I also tried passing join=True as a parameter to spawn), setting the start method to "spawn", restricting PyTorch to a single thread, and so on. None of them make any difference; the process still hangs indefinitely.
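For concreteness, these are the variants I tried around the spawn call (sketched here with the same names as in the script above; none of them changed the hang):

# workarounds applied before spawning
mp.set_start_method('spawn', force=True)   # force the 'spawn' start method
torch.set_num_threads(1)                   # use only one thread

# 1) chain .join() onto the spawn call (this is what the script above currently does)
mp.spawn(main, args=(arguments, world_size), nprocs=arguments.world_size).join()

# 2) pass join=True explicitly and let spawn join the worker processes itself
mp.spawn(main, args=(arguments, world_size), nprocs=arguments.world_size, join=True)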
Any help understanding this problem would be greatly appreciated!
-Paul