Hello, I’m using RPC to apply model parallelism and I don’t see any reduction in memory usage. On the contrary, the memory usage is doubled!
This is the code I’m executing with RPC + torchrun across 3 nodes (1 GPU per node): 1 master + 2 workers.
import random
import os
import time
import gc
import segmentation_models_pytorch as smp
import torch
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import numpy as np
import torch.distributed as dist
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
########################
seed_value=42
torch.manual_seed(seed_value)
np.random.seed(seed_value) # cpu vars
random.seed(seed_value) # Python
torch.use_deterministic_algorithms(True,warn_only=True)
torch.cuda.manual_seed_all(seed_value) # gpu vars
os.environ['CUBLAS_WORKSPACE_CONFIG']=":4096:8"
torch.backends.cudnn.deterministic = True #needed
torch.backends.cudnn.benchmark = False
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class RPCMyModel(torch.nn.Module):
    def __init__(self, remote_conv1, remote_conv2):
        super(RPCMyModel, self).__init__()
        # The two conv layers are RemoteModules living on the workers; the ReLU runs locally on the master.
        self.remote_conv1 = remote_conv1
        self.relu = torch.nn.ReLU()
        self.remote_conv2 = remote_conv2

    def forward(self, x):
        x = rpc.rpc_sync("worker1", self.remote_conv1.forward, args=(x,))
        x = torch.stack(x)  # x comes back unstacked when using cuda
        x = self.relu(x)
        x = rpc.rpc_sync("worker2", self.remote_conv2.forward, args=(x,))
        x = torch.stack(x).cuda()
        return x
def init_worker(rank, world_size):
    rpc_backend_options = TensorPipeRpcBackendOptions(
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:52355",
        num_worker_threads=world_size,
    )
    # Master
    if rank == 0:
        print("init rpc")
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        print(rank, "remote module")
        remote_conv1 = RemoteModule(
            "worker1/cuda",
            torch.nn.Conv2d,
            args=(1, 32),
            kwargs={"kernel_size": 3, "stride": 1, "padding": 1},
        )
        remote_conv2 = RemoteModule(
            "worker2/cuda",
            torch.nn.Conv2d,
            args=(32, 1),
            kwargs={"kernel_size": 3, "stride": 1, "padding": 1},
        )
        remote_modules = [remote_conv1, remote_conv2]
        # Initialize the model, loss function, and optimizer
        model = RPCMyModel(*remote_modules)
        # Retrieve all model parameters as RRefs for the DistributedOptimizer.
        model_parameter_rrefs = model.remote_conv1.remote_parameters()
        for param in model.relu.parameters():  # ReLU has no parameters, so this loop is a no-op
            model_parameter_rrefs.append(RRef(param))
        model_parameter_rrefs.extend(model.remote_conv2.remote_parameters())
        criterion = torch.nn.MSELoss()
        # Setup distributed optimizer
        optimizer = DistributedOptimizer(
            optim.Adam,
            model_parameter_rrefs,
            lr=0.001,
        )
        # Training loop
        num_epochs = 20
        # Generate a random tensor of size (2, 1, 512, 512)
        input_tensor = torch.rand((2, 1, 512, 512))
        # Create a target tensor of zeros with the same size as the input
        target_tensor = torch.zeros_like(input_tensor, device=DEVICE)
        for epoch in range(num_epochs):
            with dist_autograd.context() as context_id:
                # Forward pass
                output = model(input_tensor)
                # Calculate the mean squared error (MSE) loss
                loss = criterion(output, target_tensor)
                # Backward pass and optimization
                dist_autograd.backward(context_id, [loss])
                optimizer.step(context_id)
            # Print the loss every few epochs
            if (epoch + 1) % (num_epochs // 5) == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
    elif rank > 0:  # i.e. rank in [1, 2]
        # dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
        # Initialize RPC.
        worker_name = "worker{}".format(rank)
        print(rank, "init rpc")
        rpc.init_rpc(
            worker_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        print(rank, "pos rpc")
        # The worker just waits for RPCs from the master.

    # Every rank (master included) blocks here until all RPCs are done.
    rpc.shutdown()
    print(rank, "RPC shutdown.")
####
if __name__ == "__main__":
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    print(rank, world_size)
    init_worker(rank, world_size)
And the command line on each node:
TP_SOCKET_IFNAME=<interface> GLOO_SOCKET_IFNAME=<interface> time torchrun --nnodes=3 --nproc-per-node=1 --master-addr=<DNS NAME> --node_rank=<From 0 to 2> --master-port=52355 --start-method=spawn rpc-torchrun.py
Executing without RPC, GPU memory usage is ~700MB (verified with pynvml). With RPC, memory already reaches ~600MB from the RemoteModule creation alone, and it climbs to ~1700MB when executing the forward calls, more than double the non-RPC usage.
I’m using nvidia-smi to watch the memory usage in the RPC runs. I know it’s not a 100% reliable source for this, but I think the change in memory is significant enough to suggest that something weird is happening.
I have no idea where this huge memory increase could come from. Is this normal, or is it a problem with my code/workflow in RPC?
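For reference, this is roughly how I could cross-check the numbers from inside each process instead of relying only on nvidia-smi. The gpu_mem_summary helper below is just a sketch of mine (not part of the RPC API), and it assumes the function is defined at module level in the same script that every rank runs, so the workers can resolve it when the master calls it over RPC:

import torch
import torch.distributed.rpc as rpc

def gpu_mem_summary(tag):
    # memory_allocated/memory_reserved only count what the PyTorch caching
    # allocator holds in this process; the per-process CUDA context overhead
    # that nvidia-smi shows is not included here.
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    return f"{tag}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB"

# On the master, e.g. right after a forward pass:
print(gpu_mem_summary("master"))
for worker in ("worker1", "worker2"):
    print(rpc.rpc_sync(worker, gpu_mem_summary, args=(worker,)))

That would at least separate the tensor memory PyTorch actually allocates from the CUDA context each process creates.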
Thanks.
The code for sequential execution:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from pynvml import *
seed_value=42
torch.manual_seed(seed_value)
np.random.seed(seed_value) # cpu vars
random.seed(seed_value) # Python
torch.use_deterministic_algorithms(True,warn_only=True)
torch.cuda.manual_seed_all(seed_value) # gpu vars
os.environ['CUBLAS_WORKSPACE_CONFIG']=":4096:8"
torch.backends.cudnn.deterministic = True #needed
torch.backends.cudnn.benchmark = False
# Define a simple neural network model
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Same layers as the RPC version, but all local.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x
# Initialize the model, loss function, and optimizer
model = MyModel()
model = model.cuda()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Generate a random tensor of size (2, 1, 512, 512)
input_tensor = torch.rand((2, 1, 512, 512))
input_tensor = input_tensor.cuda()
# Create a target tensor of zeros with the same size as the input
target_tensor = torch.zeros_like(input_tensor)
target_tensor = target_tensor.cuda()
# Training loop
num_epochs = 100
nvmlInit()
for epoch in range(num_epochs):
    h = nvmlDeviceGetHandleByIndex(0)
    # Forward pass
    output = model(input_tensor)
    # Calculate the mean squared error (MSE) loss
    loss = criterion(output, target_tensor)
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Print the loss and GPU memory usage every 10 epochs
    if (epoch + 1) % 10 == 0:
        i0 = nvmlDeviceGetMemoryInfo(h)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
        print(f'model used : {round(i0.used / 1024**3, 2)} GiB\n')  # i0.used is reported in bytes