I understand that MPS is the way in which Nvidia supports CUDA multithreading/multiprocessing.
hmm we need to be more specific. Each process receives its own CUDA context on each device it uses. Per-device contexts are shared by all CPU threads within the process. Any CPU thread in the process may submit work to any CUDA stream (kernel launches and the stream API are thread safe), and that work may run concurrently with work submitted from other CPU threads. And of course, each kernel may use thousands of GPU threads.
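For example, here's a minimal sketch (assuming PyTorch and a single visible CUDA device; the thread count and matrix sizes are arbitrary) of several CPU threads in one process each submitting matmuls on their own stream:

import threading
import torch

def worker(stream):
    # Route this thread's work onto its own stream; kernel launches and the
    # stream API are thread safe, so no locking is needed here.
    with torch.cuda.stream(stream):
        a = torch.randn(4096, 4096, device="cuda")
        b = torch.randn(4096, 4096, device="cuda")
        c = a @ b  # may overlap with matmuls launched from the other threads
    stream.synchronize()

streams = [torch.cuda.Stream() for _ in range(4)]
threads = [threading.Thread(target=worker, args=(s,)) for s in streams]
for t in threads:
    t.start()
for t in threads:
    t.join()

All four threads share the process's single per-device context, so no MPS is needed for this case.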
By default (without MPS), each device runs kernels from only one context (i.e., one process) at a time. If several processes target the same device, their kernels can't run concurrently, and the GPU context-switches between processes. MPS multiplexes kernels from different processes so that kernels from any thread of any process targeting that device CAN run concurrently (I'm not sure how MPS works at a low level, but it works).
MPS is application-agnostic. After starting the MPS daemon in your shell:
nvidia-cuda-mps-control -d
all processes (Python or otherwise) that use the device have their CUDA calls multiplexed so they can run concurrently. You shouldn't need to do anything PyTorch-specific: start the MPS daemon in the background, then launch your PyTorch processes targeting the same device.
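As a rough sketch (again assuming PyTorch and one visible GPU), you could launch something like this once with the daemon running and once without, and compare the per-process wall times:

import time
import torch
import torch.multiprocessing as mp

def run(rank):
    # Both processes target the same device; with the MPS daemon running,
    # their kernels may execute concurrently instead of time-slicing.
    x = torch.randn(2048, 2048, device="cuda:0")
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(200):
        y = x @ x
    torch.cuda.synchronize()
    print(f"rank {rank}: {time.time() - t0:.3f}s")

if __name__ == "__main__":
    mp.spawn(run, nprocs=2)

How much MPS helps depends on how much each process underutilizes the GPU on its own. To stop the daemon afterwards, echo quit | nvidia-cuda-mps-control.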
One thing I don't know is whether NCCL allreduces in PyTorch can cope with the data from all processes actually residing on one GPU. I've never seen it tried. It sounds like your case doesn't need inter-process NCCL comms, though.
MPS has been around for years and works on any recent GPU generation. It is NOT the same thing as "Multi-Instance GPU" (MIG), which is Ampere-specific. (I think MIG sandboxes client processes more aggressively than MPS, providing better per-process fault isolation, among other things. MPS should be fine for your case.)