Shared CUDA Tensor Consumes GPU Memory

I tried to pass a CUDA tensor into a multiprocessing spawn. As I understand it, the CUDA tensor is automatically treated as shared memory as well (which is supposed to be a no-op according to the docs). However, it turns out that this makes PyTorch unable to reserve a significant amount of my GPUs' memory (2-3 GB), which is probably the storage reserved to make the shared memory work. Is this expected?

Could you post a minimal code snippet showing your issue, please?

Please refer to the code in [1] to reproduce the issue.

Based on my own experiment with 4 GPUs of 16 GB each:

Running “python3 src/test_shared.py” finishes with:

[P0] [GPU Mem (MB)] total: 16130.5, reserved = 13826.0, allocated: 13824.0625, free: 1.9375, step: 99
[P1] [GPU Mem (MB)] total: 16130.5, reserved = 13826.0, allocated: 13824.0625, free: 1.9375, step: 99
[P2] [GPU Mem (MB)] total: 16130.5, reserved = 13826.0, allocated: 13824.0625, free: 1.9375, step: 99
[P3] [GPU Mem (MB)] total: 16130.5, reserved = 13826.0, allocated: 13824.0625, free: 1.9375, step: 99

However, running “python3 src/test_shared.py --shm_cuda” throws the following error:

[P1] [GPU Mem (MB)] total: 16130.5, reserved = 9858.0, allocated: 9856.0625, free: 1.9375, step: 68
Traceback (most recent call last):
  File "[...]/test_shared.py", line 84, in <module>
    main()
  File "[...]/test_shared.py", line 80, in main
    mp.spawn(train, nprocs=num_gpus, args=(args, num_gpus, shm_list))
  File "[...]/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "[...]/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "[...]/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "[...]/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "[...]/test_shared.py", line 53, in train
    inp = torch.rand((args.batch_size, args.inp_size)).to(rank)
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.75 GiB total capacity; 9.63 GiB already allocated; 59.88 MiB free; 9.63 GiB reserved in total by PyTorch)

As can be seen, PyTorch seems unable to reserve more memory. The shared tensors are only (2^10) × (2^14) elements each, though. Any thoughts on how it could end up taking more than 4 GB?
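For reference, the size of each shared tensor can be worked out directly. A minimal sketch, assuming the sample code's default `--batch_size` and `--out_size` (2**10 and 2**14) and PyTorch's default float32 dtype (4 bytes per element):

```python
# Size of one shared tensor created by torch.zeros((batch_size, out_size)),
# assuming the script's default arguments and float32 (4 bytes/element).
BATCH_SIZE = 2 ** 10
OUT_SIZE = 2 ** 14
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2


def tensor_mib(rows: int, cols: int, itemsize: int = BYTES_PER_FLOAT32) -> float:
    """Return the storage size of a dense rows x cols tensor in MiB."""
    return rows * cols * itemsize / MIB


per_tensor = tensor_mib(BATCH_SIZE, OUT_SIZE)
num_gpus = 4  # one shared tensor per GPU in the 4-GPU setup described above
all_tensors = per_tensor * num_gpus

print(f"one shared tensor: {per_tensor} MiB")               # 64.0 MiB
print(f"all {num_gpus} shared tensors: {all_tensors} MiB")  # 256.0 MiB
```

The 64 MiB figure also matches the "Tried to allocate 64.00 MiB" in the OOM message, since `inp` has the same (2**10, 2**14) shape. So the shared tensors themselves add up to only 256 MiB, far below the several GB in question.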

[1] Sample code

import os
from argparse import ArgumentParser

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

def load_args():
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--gpu_limit', type=int, default=4, help='Max num of GPUs used')
    parser.add_argument('--num_steps', type=int, default=100, help='Number of training steps')
    parser.add_argument('--batch_size', type=int, default=(2**10), help='Batch size')
    parser.add_argument('--inp_size', type=int, default=(2**14), help='Input size')
    parser.add_argument('--out_size', type=int, default=(2**14), help='Output size')
    parser.add_argument('--shm_cuda', action='store_true')

    args = parser.parse_args()
    return args


def init_dist(rank, world_size, backend="nccl"):
    master_addr = os.environ['MASTER_ADDR']
    master_port = os.environ['MASTER_PORT']

    dist.init_process_group(
        backend,
        init_method="tcp://%s:%s" % (master_addr, master_port),
        rank=rank,
        world_size=world_size)


def log_gpu_mem(local_rank, step, unit="MB"):
    u2d = {"KB": 1024, "MB": 1024 ** 2}
    d = u2d[unit]

    total = torch.cuda.get_device_properties(local_rank).total_memory / d
    reserved = torch.cuda.memory_reserved(local_rank) / d
    allocated = torch.cuda.memory_allocated(local_rank) / d
    free = reserved - allocated

    s = f"[GPU Mem ({unit})] total: {total}, reserved = {reserved}, allocated: {allocated}, free: {free}"
    print(f"[P{local_rank}] " + s + f", step: {step}")


def train(rank, args, world_size, shm_list):
    torch.cuda.set_device(rank)
    init_dist(rank, world_size)

    model = torch.nn.Linear(args.inp_size, args.out_size).to(rank)

    record = []
    for step in range(args.num_steps):
        inp = torch.rand((args.batch_size, args.inp_size)).to(rank)
        out = model(inp)
        record.append(out)

        shm_list[step % len(shm_list)].data.copy_(out.data)

        log_gpu_mem(rank, step)


def main():
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        raise Exception("At least 2 GPUs are required to run this script")

    args = load_args()
    num_gpus = min(num_gpus, args.gpu_limit)

    print(f"Using {num_gpus} GPUs for testing")

    shm_list = []
    for i in range(num_gpus):
        shm = torch.zeros((args.batch_size, args.out_size))
        if args.shm_cuda:
            shm = shm.to(i)

        shm_list.append(shm)

    mp.spawn(train, nprocs=num_gpus, args=(args, num_gpus, shm_list))


if __name__ == '__main__':
    main()

Also, does anyone have a reference on what happens behind the scenes when a CUDA tensor is converted into shared memory? The docs seem to say that it's a no-op; however, is it really just that a reference to the GPU memory location is passed to multiple processes? Any recommended reading to better understand this would be great.

Sorry, I don’t quite understand your use case, as you are not using tensor.share_memory_(), which seems to be the op you are referring to.
In any case, I also see identical memory usage in both use cases (i.e., with and without the additional argument).

I believe that all tensors passed to multiprocessing spawn are converted into shared memory, aren't they?
According to [1], that’s the case when using a multiprocessing queue. I experimented with this earlier, and I believe that tensors passed as args when spawning subprocesses are indeed converted to shared memory. Alternatively, you can add shm.share_memory_() before appending it to the list; I believe the result will be the same.

Anyway, thanks for looking into it! Do you have any thoughts on what might cause such a difference in memory usage?

[1] Multiprocessing best practices — PyTorch 1.9.1 documentation

No, since I’m seeing the same memory usage and thus cannot reproduce the issue.
Could you post your output of both runs for the first ~10 steps?

Oh, sorry for not communicating this clearly earlier. The issue I raised in this thread is that PyTorch could not reserve as much memory as it did in the run without the “--shm_cuda” argument.

Take a look at my sample output right before the error is thrown:

[P1] [GPU Mem (MB)] total: 16130.5, reserved = 9858.0, allocated: 9856.0625, free: 1.9375, step: 68
...
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.75 GiB total capacity; 9.63 GiB already allocated; 59.88 MiB free; 9.63 GiB reserved in total by PyTorch)

After reserving ~10 GB, it ran out of memory even though ~16 GB is available in total (no other process was using the GPU). In the earlier run, by contrast, the program finished after reserving ~13 GB:

[P0] [GPU Mem (MB)] total: 16130.5, reserved = 13826.0, allocated: 13824.0625, free: 1.9375, step: 99
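These numbers can be turned into a rough accounting of the memory on GPU 0 that is not held by process 0's caching allocator (for example CUDA contexts or other processes' allocations). A minimal sketch that uses only values from the logs and the OOM message; the message rounds to two decimals in GiB, so the result is approximate. Note that the script's own "free" field is reserved - allocated, i.e. unused cache inside PyTorch, not free device memory.

```python
# Rough accounting of GPU memory outside PyTorch's caching allocator,
# using only the values printed in the logs / OOM message above.
MIB_PER_GIB = 1024.0

# --shm_cuda run: values from the OOM message on GPU 0 (rounded by PyTorch)
oom_total_gib = 15.75
oom_reserved_gib = 9.63
oom_free_mib = 59.88


def outside_allocator_mib(total_mib: float, reserved_mib: float, free_mib: float) -> float:
    """Device memory that is neither reserved by this process's allocator nor free."""
    return total_mib - reserved_mib - free_mib


shm_outside = outside_allocator_mib(
    oom_total_gib * MIB_PER_GIB, oom_reserved_gib * MIB_PER_GIB, oom_free_mib
)

# Baseline run: free device memory at the end is unknown, so total - reserved
# is only an upper bound on the memory used outside the allocator.
base_total_mib = 16130.5
base_reserved_mib = 13826.0
base_outside_upper = base_total_mib - base_reserved_mib

diff = shm_outside - base_outside_upper
print(f"--shm_cuda run, outside allocator: ~{shm_outside:.0f} MiB")
print(f"baseline run, outside allocator: <= {base_outside_upper:.1f} MiB")
print(f"difference: at least ~{diff:.0f} MiB")
```

By this estimate, at least ~3.8 GiB more of GPU 0 was occupied outside process 0's allocator in the --shm_cuda run than in the baseline, in line with the several-GB gap described above.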

The memory usage is indeed the same in both cases; it’s just that PyTorch somehow couldn’t reserve the same amount of memory.
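One way to check that the usage really is identical is to reconstruct the logged `allocated` values from the sample code's defaults. A minimal sketch under these assumptions: the `Linear(2**14, 2**14)` weight and bias are float32; each step keeps `out` (64 MiB) alive through `record`; and, presumably, autograd also keeps that step's `inp` (another 64 MiB) alive, because `out` requires grad and the linear layer needs its input for the backward pass. Under these assumptions the formula reproduces both logged values exactly:

```python
# Reconstruct the "allocated" column of the logs from the script's defaults.
# Assumptions (not measured): float32 everywhere; each step retains `out` via
# `record` and, presumably, `inp` via the autograd graph of `out`.
MIB = 1024 ** 2
FLOAT32 = 4
IN_SIZE = OUT_SIZE = 2 ** 14
BATCH_SIZE = 2 ** 10

model_mib = (OUT_SIZE * IN_SIZE + OUT_SIZE) * FLOAT32 / MIB  # weight + bias
per_step_mib = (BATCH_SIZE * IN_SIZE + BATCH_SIZE * OUT_SIZE) * FLOAT32 / MIB  # saved inp + out


def expected_allocated_mib(step: int) -> float:
    """Allocated MiB after finishing `step` (0-based), as logged by log_gpu_mem."""
    return model_mib + per_step_mib * (step + 1)


print(model_mib, per_step_mib)     # 1024.0625 128.0
print(expected_allocated_mib(99))  # 13824.0625 (baseline run, step 99)
print(expected_allocated_mib(68))  # 9856.0625  (--shm_cuda run, step 68)
```

Because the --shm_cuda run matches the same formula at step 68, the memory handed out by PyTorch's caching allocator grows identically in both runs.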

I purposely wrote the sample code to run out of memory, so if your GPUs have a different total memory, you will probably need to decrease or increase the number of steps (--num_steps) to reproduce the issue.
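To adapt --num_steps to a GPU with a different total memory, one option is to keep the same headroom as the 16 GB run (total minus peak allocated, about 2306 MiB). A minimal sketch; `pick_num_steps` is a hypothetical helper, and it assumes the default sizes, i.e. the 1024.0625 MiB model plus 128 MiB per step that matches the logged `allocated` values:

```python
import math

# Hypothetical helper: choose --num_steps for a GPU with `total_mib` of memory
# so that peak allocation leaves the same headroom as the 16 GB run above.
# Assumes the script's default sizes: the model takes 1024.0625 MiB and each
# step adds 128 MiB (values that match the logged "allocated" numbers).
MODEL_MIB = 1024.0625
PER_STEP_MIB = 128.0
REF_TOTAL_MIB = 16130.5      # total memory reported in the logs
REF_PEAK_MIB = 13824.0625    # allocated at step 99 with --num_steps 100
REF_HEADROOM_MIB = REF_TOTAL_MIB - REF_PEAK_MIB


def pick_num_steps(total_mib: float, headroom_mib: float = REF_HEADROOM_MIB) -> int:
    """Largest step count whose peak allocation stays within total - headroom."""
    budget = total_mib - headroom_mib - MODEL_MIB
    return max(0, math.floor(budget / PER_STEP_MIB))


print(pick_num_steps(REF_TOTAL_MIB))  # 100, the default used above
print(pick_num_steps(8192.0))         # hypothetical 8 GiB card
```

The headroom heuristic is only a guess: the baseline run has to fit while the --shm_cuda run should still fail, so the value may need tuning by hand.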