Torch DDP crashes with an OOM error during multi-GPU model inference, while the same model runs fine on a single GPU

I am using accelerate to perform multi-GPU inference with OpenLLaMA models (3B/13B). Both models run inference on a single GPU perfectly fine, even with a large batch size of 32. Since my machine has more than one GPU, I want to run inference in parallel, so I tried both torch DDP and Hugging Face accelerate. Both crash with an OOM error for the 13B model, and for the 3B model they use about 3x the memory compared to running on a single GPU. I understand DDP may have some overhead (although I am only doing inference, no training), but 3x still seems odd. I have attached minimal code to reproduce the error; the first few lines give the exact commands to run and the packages to install in the virtual environment. I have reproduced this issue on 3 separate machines. Why is this happening, and how can I resolve it? Thank you so much! Your help is very appreciated.

'''
Commands to run:
1. python mini_example.py --use_single_gpu
2. accelerate launch --multi_gpu --gpu_ids=0,1,2,3 --mixed_precision=fp16 --num_processes=2 --num_machines=1 --main_process_port=29500 mini_example.py --use_accelerate
3. python mini_example.py --use_ddp

Packages used:
pip install transformers
pip install torch
pip install accelerate
pip install sentencepiece
pip install protobuf
'''
import torch
import warnings
warnings.filterwarnings("ignore")
from transformers import LlamaTokenizer, LlamaForCausalLM
import os

model_path = 'openlm-research/open_llama_13b'

import subprocess
def get_gpu_memory_usage():
    try:
        result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'])
        return list(map(int, result.decode('utf-8').strip().split('\n')))
    except Exception as e:
        print(f"Error in fetching GPU status: {e}")
        return []


## This works, with a peak memory usage of about 25 GB on a single GPU
def single_gpu(args):
    args.device = "cuda:0"
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=args.device)

    prompt = 'Q: What is the largest animal?\nA:'
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(args.device)

    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")
    generation_output = model.generate(
        input_ids=input_ids, max_new_tokens=101
    )
    print(tokenizer.decode(generation_output[0]))


## This crashes with an OOM error
def multi_gpu_accelerate_load_to_memory():
    from accelerate import Accelerator
    accelerator = Accelerator()
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
    model = accelerator.prepare(model)

    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def worker_ddp(rank, args):
    args.device = "cuda:" + str(args.device_ids[rank])
    # tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=args.device)
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:5438', world_size=args.num_devices, rank=rank)

    model = model.to(args.device)
    torch.cuda.set_device(args.device)
    model = DDP(model, device_ids=[args.device_ids[rank]])

    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")


## This crashes with an OOM error
def multi_gpu_ddp():
    args.device_ids = list(map(int, args.gpu_ids))
    args.num_devices = len(args.device_ids)
    import torch.multiprocessing as mp
    mp.set_start_method("spawn", force=True)
    os.environ["NCCL_P2P_DISABLE"] = "1"
    os.environ["NCCL_IB_DISABLE"] = "1"
    import time
    time.sleep(5)
    print("LAUNCHING DDP ON", args.num_devices, "GPUs: ", args.device_ids)
    mp.spawn(worker_ddp, nprocs=args.num_devices, args=(args,))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_single_gpu", action="store_true")
    parser.add_argument("--use_accelerate", action="store_true")
    parser.add_argument("--use_ddp", action="store_true")
    parser.add_argument('--gpu_ids', type=str, nargs="+", default=["0", "1", "2", "3"])
    args = parser.parse_args()
    if args.use_single_gpu:
        single_gpu(args)                            ## Works perfectly fine
    elif args.use_accelerate:
        multi_gpu_accelerate_load_to_memory()       ## Crashes with OOM error
    elif args.use_ddp:
        multi_gpu_ddp()                             ## Crashes with OOM error

There could indeed be some memory overhead, but as a first step, could you check whether anything changes when you run your model under torch.no_grad()? Otherwise you may be keeping intermediate activations around, which wastes memory.
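A minimal sketch of that check, reusing the model, tokenizer and input_ids from your single_gpu path above (max_new_tokens is just illustrative):

import torch

# Disable autograd bookkeeping for the forward passes inside generate,
# so no intermediate activations are kept around for a backward pass.
with torch.no_grad():
    generation_output = model.generate(input_ids=input_ids, max_new_tokens=101)
print(tokenizer.decode(generation_output[0]))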

Furthermore, DDP is somewhat overkill for parallel inference. I would check whether simply splitting your batches manually and calling to(device='cuda:...') suffices, since even from a single thread the CUDA kernel launches are asynchronous.
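A rough sketch of that alternative, assuming two visible GPUs and a small list of prompts (the devices, prompts and chunking below are illustrative, not part of your script):

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

model_path = 'openlm-research/open_llama_13b'
devices = ["cuda:0", "cuda:1"]                           # GPUs to spread the work over
prompts = ["Q: What is the largest animal?\nA:"] * 8     # example batch

tokenizer = LlamaTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token            # Llama tokenizers often lack a pad token

# One plain fp16 copy of the model per GPU; no DDP wrapper around it.
models = [
    LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(d)
    for d in devices
]

outputs = []
with torch.no_grad():
    # Split the prompts into one chunk per device and generate on each GPU in turn.
    chunks = [prompts[i::len(devices)] for i in range(len(devices))]
    for device, model, chunk in zip(devices, models, chunks):
        enc = tokenizer(chunk, return_tensors="pt", padding=True).to(device)
        out = model.generate(input_ids=enc.input_ids,
                             attention_mask=enc.attention_mask,
                             max_new_tokens=101)
        outputs.extend(tokenizer.decode(o, skip_special_tokens=True) for o in out)

print("\n---\n".join(outputs))

This way each GPU holds only the model weights plus the generation cache, and you avoid any DDP wrapper state entirely.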