multiprocessing.get_context('spawn').Pool creates too many processes

I am trying to train multiple models in parallel on a single GPU. Monitoring the environment while the script runs, I can see two issues that I did not expect:

  1. All of the tasks start running immediately, instead of being limited to 3;
  2. Only one of the tasks actually does anything, and when it finishes, the whole run fails with an error.

This fills up GPU memory with processes that do nothing and eventually fail. I don't understand why this isn't working as I expect, which is for the pool to train three models in parallel in separate processes and then, as those processes complete, start the remaining two (a minimal sketch of that expectation is below).
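For reference, here is a minimal sketch of that expectation using a plain Pool with nothing torch-related (the work function is just a stand-in for the training call):

import time
from multiprocessing import get_context

def work(i):
    time.sleep(2.)  # stand-in for training one model
    return i

if __name__ == "__main__":
    # processes=3 should mean only three workers are busy at any time;
    # tasks 3 and 4 should only start after earlier ones finish
    with get_context('spawn').Pool(processes=3) as pool:
        print(pool.map(work, range(5)))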

A script to reproduce the issue:

import os
import time
import logging
from logging import FileHandler
from queue import Queue

from tqdm import tqdm
import torch
import torch.multiprocessing as multiprocessing

from typing import Optional, Tuple

    
class FakeDataset(torch.utils.data.Dataset):
    def __init__(self, n_data:int = 250000, target: int = 30):
        self.n_data = n_data
        self.target = target
        
        self.X, self.y = self.get_data()
        
    def __len__(self) -> int:
        return self.X.shape[0]
    
    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.X[idx], self.y[idx]
    
    def get_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            torch.rand((self.n_data, 4032, 2)).share_memory_(),
            torch.rand(self.n_data).share_memory_()
        )
    
    @property
    def size(self) -> float:
        """ total stored data size in GiB """
        return sum([u.element_size() * u.nelement() / 2 ** 30 for u in [self.X, self.y]])

    
class LSTMRegressor(torch.nn.Module):
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers,
                 output_dim,
                 cuda: bool = False
                 ):
        super().__init__()
        self.hidden = None  # to save the hidden state
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim

        self.lstm = torch.nn.LSTM(input_size=input_dim,
                                  hidden_size=hidden_dim,
                                  num_layers=num_layers,
                                  batch_first=True  # easy use of Dataset
                                  )

        # The linear layer that maps from hidden state space to tag space
        self.linear = torch.nn.Linear(hidden_dim, output_dim)

        if cuda and torch.cuda.is_available():
            self._move_to_cuda()

    def forward(self, x: torch.Tensor):
        lstm_out, self.hidden = self.lstm(x)
        y = self.linear(lstm_out[:, -1, :])  # only one prediction per data point
        return y.flatten()  # many-to-one predictions

    def _move_to_cuda(self):
        cuda_str = 'cuda:{:d}'.format(torch.cuda.current_device())
        self.to(cuda_str)
        
    @property
    def device(self):
        return self.linear.weight.device


DEFAULT_PARAMS = {
    'learning_rate': 1e-2, 
    'weight_decay': 0e-3, 
    'batch_size': 1024, 
    'hidden_dim': 8, 
    'num_layers': 1
}
    

def train_lstm(
    train_dataset: FakeDataset, 
    val_dataset: FakeDataset,
    save_dir: str,
    params: Optional[dict] = DEFAULT_PARAMS,
    index_number: int = 0
):
    USE_CUDA = True
    pid = str(multiprocessing.current_process().pid)
    
    save_dir = os.path.join(save_dir, f"{index_number:03d}")
    os.makedirs(save_dir, exist_ok=True)
    
    ### configure logging
    fmt_str = '%(asctime)s | %(name)-18s | %(levelname)-8s | %(message)s'
    dt_str = "%Y-%m-%dT%H:%M:%S"
    formatter = logging.Formatter(fmt=fmt_str, datefmt=dt_str)
    logger = logging.getLogger(f"SCAN-proc-{pid:s}")
    logger.handlers = []  # prevents multiple logging
    logger.setLevel(logging.DEBUG)
    sfn = os.path.join(save_dir, 'scan_{:03d}.log'.format(index_number))
    cch = logging.FileHandler(sfn)
    cch.setFormatter(formatter)
    logger.addHandler(cch)
    logger.info('Logger initialized')
    logger.info('Run parameters:')
    for key, value in params.items():
        logger.info(f"\t{key:s}: {str(value):s}")
    
    batch_size = params['batch_size']
    logger.info(f"Training scenario index {index_number:d}")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, 
        shuffle=True, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, 
        shuffle=True, pin_memory=True
    )
    
    network_params = {
        'input_dim': 2,
        'hidden_dim': params['hidden_dim'],
        'num_layers': params['num_layers'],
        'output_dim': 1,
        'cuda': USE_CUDA
    }

    nnet = LSTMRegressor(**network_params)
    
    optimizer = torch.optim.Adam(
        nnet.parameters(), 
        lr=params['learning_rate'], 
        weight_decay=params['weight_decay']
    )
    loss_function = torch.nn.MSELoss(reduction='none')
    
    logger.info('starting training')
    
    epoch = 0
    n_epochs = 2
    train_loss_tracker = []
    val_loss_tracker = []
    best_val_loss = 1e10
    
    for _ in tqdm(range(n_epochs), desc='training epoch'):
        train_examples = 0
        train_loss = 0.
        nnet.train()

        ### run through all the training data
        queue = Queue()
        # for progress bar, see https://stackoverflow.com/a/45808255/6024187
        pbar = tqdm(total=len(train_loader), desc='batch_idx', leave=False)
        train_iter = iter(train_loader)
        next_batch = next(train_iter) # start loading the first batch: X, y
        if USE_CUDA and torch.cuda.is_available():
            next_batch = [u.cuda(non_blocking=True) for u in next_batch]  # with pin_memory=True and non_blocking=True, the copy to the GPU does not block
        queue.put(next_batch)
        batch_idx = 0
        while not queue.empty():
            # use the data
            data, targets = queue.get()
            try:  # try to put more data on the queue
                next_batch = next(train_iter) # X, y
                if USE_CUDA and torch.cuda.is_available():
                    next_batch = [u.cuda(non_blocking=True) for u in next_batch]
                queue.put(next_batch)
            except StopIteration:
                pass

            ### do some training here
            optimizer.zero_grad()
            outputs = nnet(data)
            loss = loss_function(outputs, targets)  # compute loss
            loss.mean().backward()  # backprop
            optimizer.step()  # training step
            train_examples += len(targets)
            train_loss += loss.sum().item()

            batch_idx += 1
            pbar.update(n=1)

        pbar.close()

        train_loss /= (train_examples + (1 if train_examples == 0 else 0))  # mean loss; the extra term avoids division by zero
        train_loss_tracker.append((epoch, train_loss))
        
        if (epoch + 1) % 1 == 0 or epoch == 0:
            val_examples = 0
            val_loss = 0.
            nnet.eval()
            
            ### run through all the validation data
            vq = Queue()
            # for progress bar, see https://stackoverflow.com/a/45808255/6024187
            vbar = tqdm(total=len(val_loader), desc='val_idx', leave=False)
            val_iter = iter(val_loader)
            next_batch = next(val_iter) # start loading the first batch: X, y
            if USE_CUDA and torch.cuda.is_available():
                next_batch = [u.cuda(non_blocking=True) for u in next_batch]  # with pin_memory=True and non_blocking=True, the copy to the GPU does not block
            vq.put(next_batch)
            val_idx = 0
            while not vq.empty():
                 # use the data
                vdata, vtargets = vq.get()
                try:  # try to put more data on the queue
                    next_batch = next(val_iter) # X, y
                    if USE_CUDA and torch.cuda.is_available():
                        next_batch = [u.cuda(non_blocking=True) for u in next_batch]
                    vq.put(next_batch)
                except StopIteration:
                    pass
                with torch.no_grad():
                    voutputs = nnet(vdata)  # compute predictions
                    vloss = loss_function(voutputs, vtargets)  # compute loss
                    val_examples += len(vtargets)
                    val_loss += vloss.sum().item()
                
                val_idx += 1
                vbar.update(n=1)
                
            vbar.close()

            val_loss /= (val_examples + (1 if val_examples == 0 else 0))  # mean loss; the extra term avoids division by zero
            val_loss_tracker.append((epoch, val_loss))

            ss = 'Train Epoch: {: 3d}, Train Loss: {:.1f}, Val Loss: {:.1f}'.format(
                    epoch,
                    train_loss,
                    val_loss
                )
            if val_loss < best_val_loss:
                # save model checkpoint here
                # see https://wandb.ai/wandb/common-ml-errors/reports/How-to-save-and-load-models-in-PyTorch--VmlldzozMjg0MTE#save-a-pytorch-model-checkpoint
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': nnet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'train_loss': train_loss
                }
                ptfn = os.path.join(save_dir, 'best_model_checkpoint.pth')
                torch.save(checkpoint, ptfn)
                ss = ss + '  <<< model saved'
                best_val_loss = val_loss
            
            logger.info(ss)
            
        epoch += 1
        
    logger.info('done training')
    logger.info('cleaning up the nnet and optimizer')
    del nnet
    del optimizer
    logger.info('waiting 15 seconds for memory to be freed')
    time.sleep(15.)
     
        
if __name__ == "__main__":
    SAVE_DIR = 'scans/zero'
    
    os.makedirs(SAVE_DIR, exist_ok=True)
    
    # these use shared_memory_() to share the tensors between processes
    # by reference, not by copy
    train_dataset = FakeDataset()
    val_dataset = FakeDataset()
    
    # everything after this point should be in a parallel process
    # accept the datasets and some parameters as arguments and then be on your way
    learning_rates = [1e-4, 1e-3, 1e-2]
    batch_sizes = [128, 256, 512, 1024]
    hidden_dims = [4, 8, 16]
    num_layerses = [1, 2, 3]
                        
    params = {
        'learning_rate': 1e-2, 
        'weight_decay': 1e-3, 
        'batch_size': 1024, 
        'hidden_dim': 16, 
        'num_layers': 3
    }
                        
    tasks = [
        (train_dataset, val_dataset, SAVE_DIR, params, 0),
        (train_dataset, val_dataset, SAVE_DIR, params, 1),
        (train_dataset, val_dataset, SAVE_DIR, params, 2),
        (train_dataset, val_dataset, SAVE_DIR, params, 3),
        (train_dataset, val_dataset, SAVE_DIR, params, 4),
    ]
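    # the pool below should cap concurrency at three workers;
    # maxtasksperchild=1 makes each worker exit after a single task
    # and be replaced by a freshly spawned process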
    
    with multiprocessing.get_context('spawn').Pool(processes=3, maxtasksperchild=1) as pool:
        results = pool.starmap(train_lstm, tasks)

The error after the sole running process finishes:

multiprocessing.pool.RemoteTraceback:                                                                                                            
"""
Traceback (most recent call last):
  File "/python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/python3.9/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "fake_parallel.py", line 184, in train_lstm
    loss.mean().backward()  # backprop
  File "/python3.9/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/python3.9/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.92 GiB. GPU 0 has a total capacity of 39.39 GiB of which 4.25 GiB is free. Process 1363528 has 5.25 GiB memory in use. Including non-PyTorch memory, this process has 9.97 GiB memory in use. Process 1394002 has 9.95 GiB memory in use. Process 1394001 has 9.95 GiB memory in use. Of the allocated memory 5.06 GiB is allocated by PyTorch, and 4.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "fake_parallel.py", line 302, in <module>
    results = pool.starmap(train_lstm, tasks)
  File "/python3.9/multiprocessing/pool.py", line 372, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "/python3.9/multiprocessing/pool.py", line 771, in get
    raise self._value
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.92 GiB. GPU 0 has a total capacity of 39.39 GiB of which 4.25 GiB is free. Process 1363528 has 5.25 GiB memory in use. Including non-PyTorch memory, this process has 9.97 GiB memory in use. Process 1394002 has 9.95 GiB memory in use. Process 1394001 has 9.95 GiB memory in use. Of the allocated memory 5.06 GiB is allocated by PyTorch, and 4.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(conda_env) [user@hostname dir]$ /python3.9/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 2 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Catching exceptions during the run indicates that some of the processes run out of memory and end silently, and then further processes start in their place.
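For completeness, this is roughly how I catch those exceptions (a minimal sketch; the safe_train_lstm wrapper is only a debugging aid I added and is not part of the script above):

import traceback

def safe_train_lstm(*args):
    # debugging wrapper: print the exception from inside the worker
    # instead of letting the task end without any output
    try:
        return train_lstm(*args)
    except Exception:
        traceback.print_exc()
        return None

# and in __main__:
#     results = pool.starmap(safe_train_lstm, tasks)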