torch.multiprocessing.spawn hangs after all workers complete

I have code where I need to spawn worker processes and create a new process group several times within a loop. On each iteration, I want to create the process group, train, and then destroy it.
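
Conceptually, the pattern I'm going for is the stripped-down sketch below (worker here is just a stand-in for the real main function further down):

import datetime
import os
import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

def worker(rank, world_size):
    # set up the process group, do a bit of work, then tear the group down
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size,
                       timeout=datetime.timedelta(seconds=60))
    # ... training happens here in the real code ...
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    for i in range(2):
        # I expect this call to return once every worker has exited
        mp.spawn(worker, args=(world_size,), nprocs=world_size)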

My code snippet is below:

# Using ../sensitivity/36test.py
# Functions whose Fourier degree is concentrated on higher weights are harder to learn for LSTMs with SGD

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import random
import argparse
from transformer import Transformer
import os
import itertools
import time
import datetime
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group


device = torch.device("cuda")


def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size,
                       timeout=datetime.timedelta(seconds=60))

class Trainer:
    def __init__(
            self,
            model:torch.nn.Module,
            train_data: DataLoader,
            optimizer: torch.optim.Optimizer,
            gpu_id: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = DDP(model,device_ids=[self.gpu_id])
        self.model.to(self.gpu_id)
        self.train_data=train_data
        self.optimizer = optimizer
    
    def makeBitTensor(self, x, N):
        y = format(x, "b")
        y = ("0"*(N-len(y))) + y
        return [int(z) for z in list(y)]
        
    def func_batch(self,x):
        # real code has some fancier function here
        return torch.tensor(np.array([2]*len(x))).to(self.gpu_id)
        
    def _run_batch(self,inputs, targets):
        self.optimizer.zero_grad()
        inputs.to(self.gpu_id)
        result = self.model(inputs)
        #loss = -(result*targets).mean()
        loss =  (result-targets).pow(2).mean()
        (loss).backward()
        self.optimizer.step()
        return loss.detach().cpu()
    
    def _run_epoch(self,epoch):
        
        #b_sz = len(next(iter(self.train_data))[0])
        b_sz = len(next(iter(self.train_data)))
        #print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        epoch_loss = 0
        total_records = 0
        start_time = time.time()
        
        for idx, inputs in enumerate(self.train_data):
          #inputs.to(self.gpu_id)    
          targets =self.func_batch(inputs).to(self.gpu_id)
          batch_loss = self._run_batch(inputs, targets)
          epoch_loss+=batch_loss*float(len(inputs))
          total_records+=len(inputs)
          iteration = epoch*len(self.train_data)+idx+1
            
        epoch_loss/=float(total_records)
        
        end_time = time.time()
        
        elapsed_time = end_time - start_time
        #print(f"Epoch time: {elapsed_time:.3f} seconds.")
        time_per_record_ms = float(elapsed_time*100)/float(total_records)
        #print(f"Epoch time: {elapsed_time:.3f} seconds. time per record (ms): {time_per_record_ms: .3f}")
        return epoch_loss

    def save_checkpoint(self,epoch):
        ckp = self.model.module.state_dict()
        torch.save(ckp,os.path.join(self.dir_name, f"model_{epoch}.pt"))
        print(f"Epoch {epoch} | Training checkpoint saved at model_{epoch}.pt")

    def train(self,epochs: int):
        self.model.train()
        for epoch in range(epochs):
            epoch_loss = self._run_epoch(epoch)
            if epoch_loss < 0.01:
                return
            #print("remainder: " + str(epoch % self.save_every))
            if (epoch % 20)==0 and self.gpu_id==0:
                val_loss = self.validate(1000)
                print(f" Epoch: {epoch}, EpochLoss: {epoch_loss:.3f}, ValidationLoss: {val_loss:.3f}")
        return

    def validate(self, num_samples):
          self.model.eval()
          inputs = torch.tensor([random.randint(0, 2**self.N-1) for _ in range(num_samples)]).to(self.gpu_id)
          targets = self.func_batch(inputs).to(self.gpu_id)
          result = self.model(inputs).to(self.gpu_id)
          loss = (result - targets).pow(2).mean()
          return loss.detach().cpu()
    
def load_train_objs(wd,dropout,lr,num_samples, N, dim,h,l,f,rank):
        train_set = torch.tensor([random.randint(0, 2**N-1) for _ in range(int(num_samples))]).to(rank)

        model = Transformer(dropout,N, dim, h, l, f, 1e-5,rank)
        optimizer = torch.optim.AdamW(model.parameters(), lr=float(lr), weight_decay=wd)
        return train_set, model, optimizer                


def parse_args():
    parser = argparse.ArgumentParser(description='linear spectrum non boolean test.')
    parser.add_argument('--world_size', type=int, default=8)
    parser.add_argument('--dim', type=int, default=20)
    parser.add_argument('--f', type=int, default=64)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--h', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--bs', type=int, default=32)
    parser.add_argument('--num_samples', type=int, default=100000)
    parser.add_argument('--lr', type=str,default = "6e-6")
    parser.add_argument('--wd', type=float,default = .1)
    parser.add_argument('--dropout', type=float,default = .2)
    parser.add_argument('--repeat', type=int, default=100)


    return parser.parse_args()

def main(rank, args,world_size):
      ddp_setup(rank,world_size)
      train_set,model,optimizer = load_train_objs(args.dropout, args.wd,args.lr,args.num_samples,args.N,args.dim,args.h,args.l,args.f,rank)
      model.to(rank)
      train_loader = DataLoader(
          train_set,
          shuffle=False,
          batch_size=args.bs,
          sampler = DistributedSampler(train_set)
      )
             
      trainer = Trainer( model,
                        train_loader,
                        optimizer,
                        gpu_id=rank,
                        )
      trainer.train(args.epochs)
      print("finished training, cleaning up process group...")
      destroy_process_group()
      print("finished cleaning up process group")
      return

if __name__ == "__main__":
    arguments = parse_args()
    print(arguments)
    
    losses = {}
    func_per_deg = arguments.repeat
    main_dir = f"N{arguments.N}_HidDim{arguments.dim}_L{arguments.l}_H{arguments.h}_FFDim{arguments.f}_4k"
    os.makedirs(main_dir, exist_ok=True)
  # with open("logs_width.txt", "a") as f:
  #   f.write("------------------------------------------\n")
    for i in range(2):
        start_time = time.time()
        world_size = torch.cuda.device_count()
        mp.set_start_method('spawn',force = True)
        torch.set_num_threads(1)
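        # this is the call that hangs: all the workers finish, but spawn never returns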
        mp.spawn(main,args=(arguments,world_size),nprocs=arguments.world_size).join()
        print("returned from mp.spwan")
        end_time = time.time()
        elapsed_time = round((end_time - start_time)/60,3)
        print("elapsed time for whole training process: " + str(elapsed_time))

When I run this code, it trains successfully using all 8 GPUs on the cluster I'm working on. On the first iteration of the outer loop over i, every spawned worker runs main to completion and prints "finished cleaning up process group". However, the parent process never actually returns from mp.spawn; it just hangs until the job times out.
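
In case it's useful for diagnosing this, here's the kind of check I could add right after the spawn call to see whether the worker processes are actually still alive (active_children comes from the standard multiprocessing API, which torch.multiprocessing re-exports, so I'm assuming it behaves the same here; mp is the alias from my script):

# hypothetical diagnostic, placed right after the mp.spawn(...) call in the outer loop
children = mp.active_children()
print(f"live child processes: {len(children)}")
for proc in children:
    print(f"  pid={proc.pid} alive={proc.is_alive()} exitcode={proc.exitcode}")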

As you can see from the code, I have already applied all of the easily findable workarounds to this problem: calling join() on the result of spawn (I also tried passing join=True explicitly, which I believe is the default), setting the start method to "spawn", and limiting torch to a single thread. None of them make any difference; it still hangs indefinitely.
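
The only other variant I can think of is restructuring the spawn call with join=False; as I understand it, spawn then returns a process context whose join() accepts a timeout, which would at least let me detect the hang instead of blocking forever (the 60-second timeout below is arbitrary):

# drop-in replacement for the spawn call in the outer loop
mp.set_start_method("spawn", force=True)
torch.set_num_threads(1)

ctx = mp.spawn(main, args=(arguments, world_size), nprocs=arguments.world_size, join=False)

# join(timeout) returns True once all workers have exited, False otherwise
while not ctx.join(timeout=60):
    print("workers still alive after 60s of waiting...")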

Any help understanding this problem would be greatly appreciated!

-Paul