Infinite inference when using multiprocessing with cuda for multiple models

Hi!

I am trying to use an AMD GPU to speed up inference for genetic learning. The setup has three parts: multiple servers that act as “game” environments (each server has its own socket and can “play” a map, and I have several maps); a graph NN; and “agents”, which are instances of the GeneticLearner class, each fitting its own internal weight vector that acts as the last layer of the NN.

Originally I used multiprocessing to distribute the load, and that worked fine. But when I move the model and data to the GPU, the forward call never returns, and the GPU appears to be idle.

Is there any way to fix this? Following the CUDA + multiprocessing best practices didn’t help.

My code looks roughly like this:

import queue

import torch.multiprocessing as mp
import websocket

def play_game(model, game_map, max_steps, ws_queue: queue.Queue[websocket.Websocket]):
    ...
    # take a socket from ws_queue, play the game, put the socket back, return the results

def init_fun():
    ...
    # assign the torch model to the static field "MODEL" of GeneticLearner

futures_queue = queue.SimpleQueue()
with mp.Pool(proc_num, init_fun) as executor:
    for model, game_map in games:
        ...
        future = executor.apply_async(
            play_game,
            (model, game_map, max_steps, ws_urls),  # ws_urls is a torch.multiprocessing.Manager().Queue()
            ...
        )
        futures_queue.put(future)

    while not futures_queue.empty():
        ...  # get results
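Independently of the GPU question, two details of this loop can themselves turn a failure into an endless wait: if a task raises between taking an item from the shared queue and putting it back, the other workers can end up blocked on `get()`, and `AsyncResult.get()` without a timeout waits forever. A minimal standard-library sketch of both guards; `borrowed` and `collect` are hypothetical helper names, and the queue item stands for whatever the real code stores (socket or URL):

```python
import contextlib
import queue
from multiprocessing.pool import AsyncResult
from typing import Any, Iterator, List, Optional


@contextlib.contextmanager
def borrowed(shared: Any, timeout: Optional[float] = None) -> Iterator[Any]:
    """Take one item from a shared queue and always put it back, even if the body raises."""
    # With a timeout, an exhausted queue raises queue.Empty instead of blocking forever.
    item = shared.get(timeout=timeout)
    try:
        yield item
    finally:
        shared.put(item)


def collect(futures: "queue.SimpleQueue[AsyncResult]", timeout: float) -> List[Any]:
    """Drain AsyncResult objects in submission order, waiting at most `timeout` seconds for each."""
    results = []
    while not futures.empty():
        # Raises multiprocessing.TimeoutError if the task is still running after `timeout`.
        results.append(futures.get().get(timeout=timeout))
    return results
```

In `play_game` this would read `with borrowed(ws_queue, timeout=60) as ws: ...` (60 seconds is an arbitrary choice), and the final loop could become `results = collect(futures_queue, timeout=600)`, so a stalled task raises `multiprocessing.TimeoutError` instead of waiting silently.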

I load the model like this:

def load_model(path: str) -> torch.nn.Module:
    model = ml.models.StateModelEncoder(hidden_channels=64, out_channels=8)
    model.load_state_dict(torch.load(path), strict=False)
    model.to(torch.device("cuda:0"))
    model.eval()
    return model
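The PyTorch notes on CUDA in multiprocessing state that the CUDA runtime does not support the `fork` start method (the default on Linux in Python 3.10), so `spawn` or `forkserver` is required. One common way to combine that with a loader like the one above is to build the model inside each worker through the pool initializer, instead of loading it in the parent and passing it to `apply_async`. A minimal sketch assuming a `spawn` context; `init_worker`, `run`, `make_pool` and the `_MODEL` global are hypothetical names, and `torch.multiprocessing` exposes the same `get_context` call as the standard module used here:

```python
import multiprocessing as mp
from multiprocessing.pool import Pool
from typing import Any, Callable, Optional

# CUDA cannot be initialised in a forked child, so use an explicit spawn context.
SPAWN_CTX = mp.get_context("spawn")

# Filled once per worker process by init_worker (hypothetical name).
_MODEL: Optional[Callable[[Any], Any]] = None


def init_worker(load_fn: Callable[[str], Callable[[Any], Any]], path: str) -> None:
    """Pool initializer: each worker builds its own model (e.g. via load_model)."""
    global _MODEL
    _MODEL = load_fn(path)


def run(x: Any) -> Any:
    """Task function: uses the per-worker model rather than one pickled from the parent."""
    if _MODEL is None:
        raise RuntimeError("init_worker has not run in this process")
    return _MODEL(x)


def make_pool(processes: int, load_fn: Callable[[str], Callable[[Any], Any]], path: str) -> Pool:
    """Create a spawn-based pool whose workers each call load_fn(path) once at start-up."""
    return SPAWN_CTX.Pool(processes, initializer=init_worker, initargs=(load_fn, path))
```

With the real loader this would be `with make_pool(proc_num, load_model, path) as executor:` inside an `if __name__ == "__main__":` block, because `spawn` re-imports the main module in every child, and `load_model` must be defined at module level so it can be pickled. Tasks then receive only the map and the queue, not the CUDA model. Each worker creates its own GPU context and its own copy of the weights, so GPU memory use grows with the number of processes.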

and use it like this:

device = torch.device("cuda:0")
data = data.to(device)  # data is torch_geometric.data.HeteroData
out = model(data.x_dict, data.edge_index_dict)  # calling the module runs forward()
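To see where a worker is actually stuck when the forward call never returns, the standard library's `faulthandler` can dump the Python stack of every thread after a timeout. A minimal sketch; `dump_stack_if_slow` is a hypothetical helper and the timeout is arbitrary. Only Python frames are shown, so a hang inside native GPU code appears as the deepest Python line reached (for example the forward call itself):

```python
import contextlib
import faulthandler
import sys
from typing import IO, Iterator, Optional


@contextlib.contextmanager
def dump_stack_if_slow(seconds: float, file: Optional[IO] = None) -> Iterator[None]:
    """Dump all Python thread stacks to `file` every `seconds` while the body is still running."""
    # faulthandler writes straight to the file descriptor, so `file` needs a real fileno().
    faulthandler.dump_traceback_later(seconds, repeat=True, file=file if file is not None else sys.stderr)
    try:
        yield
    finally:
        # Only one such timer exists per process; this cancels it.
        faulthandler.cancel_dump_traceback_later()
```

Wrapping the forward call in `with dump_stack_if_slow(120):` inside each worker would print a traceback to stderr every two minutes for as long as that call is blocked.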

My environment:
OS: Ubuntu 22.04
GPU: AMD Radeon Vega Frontier Edition
Python: 3.10.9 (Conda)
Packages:
pytorch-triton-rocm 2.0.1 pypi_0 pypi
torch 2.0.1+rocm5.4.2 pypi_0 pypi
torch-geometric 2.3.1 pypi_0 pypi