RPC for model parallelism increases GPU memory usage

Hello, I’m using RPC to apply model parallelism and I don’t see any reduction in memory usage. On the contrary, the memory usage is doubled!

This is the code I’m executing with RPC + torchrun across 3 nodes (1 GPU per node): 1 master + 2 workers.

import random
import os
import time
import gc

import segmentation_models_pytorch as smp
import torch
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import numpy as np

import torch.distributed as dist
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
########################


seed_value = 42
torch.manual_seed(seed_value)
np.random.seed(seed_value)  # NumPy
random.seed(seed_value)  # Python
torch.use_deterministic_algorithms(True, warn_only=True)
torch.cuda.manual_seed_all(seed_value)  # GPU
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
torch.backends.cudnn.deterministic = True  # needed
torch.backends.cudnn.benchmark = False


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class RPCMyModel(torch.nn.Module):
    def __init__(self, remote_conv1, remote_conv2):
        super().__init__()
        # The conv layers live on remote workers; only the ReLU runs locally on the master.
        self.remote_conv1 = remote_conv1
        self.relu = torch.nn.ReLU()
        self.remote_conv2 = remote_conv2

    def forward(self, x):
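        # Note: RemoteModule.forward already performs a synchronous RPC to the worker that
        # owns the module, so wrapping it in rpc.rpc_sync here adds an extra hop through that worker.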
        x = rpc.rpc_sync("worker1", self.remote_conv1.forward, args=(x,))
        x = torch.stack(x)  # the output comes back unstacked (a sequence of tensors) when the remote module is on CUDA
        x = self.relu(x)
        x = rpc.rpc_sync("worker2", self.remote_conv2.forward, args=(x,))
        x = torch.stack(x).cuda()
        return x


def init_worker(rank, world_size):
    rpc_backend_options = TensorPipeRpcBackendOptions(
        init_method = f"tcp://{os.environ['MASTER_ADDR']}:52355",
        num_worker_threads=world_size,
    )
    
    # Master
    if rank == 0:
        print("init rpc")
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        
        print(rank, "remote module")
        remote_conv1 = RemoteModule(
            "worker1/cuda",
            torch.nn.Conv2d,
            args=(1, 32),
            kwargs={"kernel_size":3, "stride":1, "padding":1}
        )

        remote_conv2 = RemoteModule(
            "worker2/cuda",
            torch.nn.Conv2d,
            args=(32, 1),
            kwargs={"kernel_size":3, "stride":1, "padding":1}
        )
    
        remote_modules = [remote_conv1, remote_conv2]
        
        # Initialize the model, loss function, and optimizer
        model = RPCMyModel(*remote_modules)
    
        # Retrieve all model parameters as rrefs for DistributedOptimizer.
        model_parameter_rrefs = model.remote_conv1.remote_parameters()
        for param in model.relu.parameters():
            model_parameter_rrefs.append(RRef(param))
        model_parameter_rrefs.extend(model.remote_conv2.remote_parameters())
        
        criterion = torch.nn.MSELoss()
        # Setup distributed optimizer
        optimizer = DistributedOptimizer(
            optim.Adam,
            model_parameter_rrefs,
            lr=0.001,
        )
        # Training loop
        num_epochs = 20
        # Generate a random tensor of size (2, 1, 512, 512)
        input_tensor = torch.rand((2, 1, 512, 512))
        # Create a target tensor of zeros with the same size as the input
        target_tensor = torch.zeros_like(input_tensor, device=DEVICE)
        
        for epoch in range(num_epochs):
            with dist_autograd.context() as context_id:
                # Forward pass
                output = model(input_tensor)
        
                # Calculate the mean squared error (MSE) loss
                loss = criterion(output, target_tensor)
            
                # Backward pass and optimization
                dist_autograd.backward(context_id, [loss])
                optimizer.step(context_id)
                
            # Print the loss for every few epochs
            if (epoch + 1) % (num_epochs//5) == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
            
    elif rank > 0:  # workers 1 and 2
        # Initialize RPC.
        worker_name = f"worker{rank}"
        print(rank, "init rpc")
        rpc.init_rpc(
            worker_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        print(rank, "pos rpc")
        # Worker just waits for RPCs from master.
  
    rpc.shutdown()
    print(rank, "RPC shutdown.")
####


if __name__ == "__main__":
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    print(rank, world_size)

    init_worker(rank, world_size)

And the command line on each node:

TP_SOCKET_IFNAME=<interface> GLOO_SOCKET_IFNAME=<interface> time torchrun --nnodes=3 --nproc-per-node=1 --master-addr=<DNS NAME> --node-rank=<0 to 2> --master-port=52355 --start-method=spawn rpc-torchrun.py

Running without RPC, GPU memory usage is ~700MB (verified with pynvml). With RPC, memory already reaches ~600MB just from the RemoteModule creation and jumps to ~1700MB when the forward calls execute, more than double the non-RPC footprint.

I’m using nvidia-smi to watch the memory usage in the RPC runs. I know it’s not a 100% reliable way to measure this, but the change is large enough to suggest that something weird is going on.
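
For reference, here is a minimal sketch (assuming the GPU is device index 0 on every node and reusing pynvml, as in the sequential script below) of how I could compare what PyTorch’s caching allocator holds against what the driver reports for the process:

import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def report_gpu_memory(tag, device_index=0):
    # Tensor memory tracked by PyTorch's caching allocator
    allocated = torch.cuda.memory_allocated(device_index) / 1024**2
    reserved = torch.cuda.memory_reserved(device_index) / 1024**2
    # Total process memory seen by the driver (includes the CUDA context,
    # cuDNN/cuBLAS workspaces and any RPC/TensorPipe buffers)
    nvmlInit()
    info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(device_index))
    print(f"[{tag}] allocated={allocated:.0f}MB reserved={reserved:.0f}MB driver used={info.used / 1024**2:.0f}MB")

Calling this right after rpc.init_rpc, after the RemoteModule construction, and after the first forward on each rank should show whether the extra memory is actual tensor allocations (allocated/reserved grow) or just CUDA context / RPC channel overhead (only the driver number grows).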

I have no idea where this huge memory increase could be coming from. Is this normal, or is there a problem with my code/RPC workflow?

Thanks.

The code for the sequential (single-GPU) execution:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os

from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

seed_value = 42
torch.manual_seed(seed_value)
np.random.seed(seed_value)  # NumPy
random.seed(seed_value)  # Python
torch.use_deterministic_algorithms(True, warn_only=True)
torch.cuda.manual_seed_all(seed_value)  # GPU
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
torch.backends.cudnn.deterministic = True  # needed
torch.backends.cudnn.benchmark = False


# Define a simple neural network model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Two conv layers with a ReLU in between (same architecture as the RPC model)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

# Initialize the model, loss function, and optimizer
model = MyModel()
model = model.cuda()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Generate a random tensor of size (2, 1, 512, 512)
input_tensor = torch.rand((2, 1, 512, 512))
input_tensor = input_tensor.cuda()

# Create a target tensor of zeros with the same size as the input
target_tensor = torch.zeros_like(input_tensor)
target_tensor = target_tensor.cuda()

# Training loop
num_epochs = 100
nvmlInit()
h = nvmlDeviceGetHandleByIndex(0)  # the handle only needs to be fetched once
for epoch in range(num_epochs):
    
    # Forward pass
    output = model(input_tensor)

    # Calculate the mean squared error (MSE) loss
    loss = criterion(output, target_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the loss for every 10 epochs
    if (epoch + 1) % 10 == 0:
        i0 = nvmlDeviceGetMemoryInfo(h)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
        print(f'GPU memory used: {round(i0.used/1024**3, 2)} GB\n')

For some weird reason, upgrading the GPU driver from version 520 to 535 “fixed” this issue. I guess I can mark this as “solved”, but I don’t know if it counts as a real solution…
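
In case anyone wants to compare environments, here is a quick sketch (using the same pynvml package as in the sequential script) for logging the versions involved:

import torch
from pynvml import nvmlInit, nvmlSystemGetDriverVersion

nvmlInit()
print("driver:", nvmlSystemGetDriverVersion())  # e.g. 520.x before, 535.x after the upgrade
print("torch :", torch.__version__, "| cuda runtime:", torch.version.cuda)
print("gpu   :", torch.cuda.get_device_name(0))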