Subprocess and CUDA memory profiling

Hello

I’m currently trying to create a subprocess that trains a model and profiles GPU memory during training.

Here is the script that gets launched:

from gpu_alloc import TraceMalloc
from dataset import PipelineDataset
from pipelinecache.layered_model import PipelinedModel
import os

import torch
import torch.distributed.rpc            # needed for init_rpc below
import torch.distributed.pipeline.sync  # needed for Pipe below
import argparse
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29600'

parser = argparse.ArgumentParser()
parser.add_argument('--input_shape', type=str, help='Input shape as a list')
parser.add_argument('--output_shape', type=str, help='Output shape as a list')
parser.add_argument('--number_gpu', type=int, help='Number of GPU')
parser.add_argument('--number_chunks', type=int, help='Number of chunks')
args = parser.parse_args()

input_shape = args.input_shape.replace("[", "").replace("]", "")
input_shape = input_shape.split(",")
input_shape = [int(x.strip()) for x in input_shape]

output_shape  = args.output_shape.replace("[", "").replace("]", "")
output_shape  = output_shape.split(",")
output_shape  = [int(x.strip()) for x in output_shape]

number_gpus   = args.number_gpu
number_chunks = args.number_chunks


trace_gpu_alloc = TraceMalloc(number_gpus)
criterion = torch.nn.CrossEntropyLoss()

torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)

with trace_gpu_alloc:

    model = PipelinedModel()

    dataset = PipelineDataset(1024, input_shape[1:], [1] if len(output_shape) == 1 else output_shape[1:])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=input_shape[0], shuffle=True)

    model = model.get_modules()
    model = torch.distributed.pipeline.sync.Pipe(model, number_chunks)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    for _ in range(3):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            inputs = inputs.to(0)
            labels = labels.to(number_gpus - 1)

            # Forward pass
            outputs = model(inputs).local_value()
            loss = criterion(outputs, labels.squeeze())

            # Backward pass and weight update
            loss.backward()
            optimizer.step()
print(trace_gpu_alloc.peaked)

and here is the class that profiles memory:

import torch

def BToMb(x): return x // (1024 * 1024)  # bytes -> MB

class TraceMalloc():
    def __init__(self, nb_gpu):
        self.nb_gpu = nb_gpu
        self.begin  = [0] * nb_gpu
        self.end    = [0] * nb_gpu
        self.peak   = [0] * nb_gpu
        self.peaked = [0] * nb_gpu

    def __enter__(self):
        for device in range(self.nb_gpu):
            self.begin[device] = torch.cuda.memory_allocated(device)
            
        return self
    
    def __exit__(self, *exc):

        for device in range(self.nb_gpu):
            self.end[device]    = torch.cuda.memory_allocated(device)
            self.peak[device]   = torch.cuda.max_memory_allocated(device)
            self.peaked[device] = BToMb(self.peak[device] - self.begin[device])
            torch.cuda.reset_peak_memory_stats(device)

        for device in range(self.nb_gpu):
            print(f"GPU n°{device}")
            print(f"    Memory begin -> {BToMb(self.begin[device])} MB")
            print(f"    Memory end   -> {BToMb(self.end[device])} MB")
            print(f"    Memory peak  -> {BToMb(self.peak[device])} MB")
            print(f"    Memory peaked  -> {self.peaked[device]} MB")

If I run the script independently from a shell, it works fine. But when I call it from another process, it only returns 0 for all the memory values.

Here is how I call it from subprocess:

p = subprocess.run(['python', dir_path,
                    '--input_shape', str(list(self.input_shape)),
                    '--output_shape', str(list(self.output_shape)),
                    '--number_gpu', str(int(self.nb_gpu)),
                    '--number_chunks', str(2)], capture_output=True, text=True)
print(p.stdout)
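
In case it helps, this is roughly how I plan to debug the call further: print the child’s stderr and return code, launch it with the same interpreter via sys.executable, and explicitly pass the parent’s environment so CUDA-related variables are inherited. This is only a sketch mirroring the call above, not my current code:

import os
import subprocess
import sys

p = subprocess.run([sys.executable, dir_path,
                    '--input_shape', str(list(self.input_shape)),
                    '--output_shape', str(list(self.output_shape)),
                    '--number_gpu', str(int(self.nb_gpu)),
                    '--number_chunks', str(2)],
                   capture_output=True, text=True,
                   env=os.environ.copy())  # pass CUDA_VISIBLE_DEVICES etc. down to the child
print("return code:", p.returncode)
print("stdout:", p.stdout)
print("stderr:", p.stderr)  # a crash in the child would only show up here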

I’ve tried enabling the shell option, but then my program hangs. I also checked whether two different processes were involved, but only one accesses my GPUs while the subprocess runs…
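
To double-check what the child process actually sees, I’m also thinking of adding a few sanity prints at the top of the launched script, something like:

import os
import torch

# Quick sanity check: does the child process see the same GPUs as the parent?
print("CUDA available:", torch.cuda.is_available())
print("Device count:  ", torch.cuda.device_count())
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))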

I don’t know if anyone has faced this or has any advice :frowning: