Hi,
I wanted to try the new tensor parallel framework. One thing that puzzles me is that it seems to consume more memory on rank 0 than without parallelize_module - where all ranks do the same thing. The code I use is adapted from the colab linked here: [RFC] PyTorch DistributedTensor · Issue #88838 · pytorch/pytorch · GitHub
with model = parallelize_module(…),I get for max memory (lparallel=True) :
0, max memory alloc: 980238336
1, max memory alloc: 117460992
2, max memory alloc: 117460992
3, max memory alloc: 117460992
and without (lparallel=False)
0, max memory alloc: 658613248
1, max memory alloc: 658613248
2, max memory alloc: 658613248
3, max memory alloc: 658613248
Is this expected? How can I reduce memory consumption on rank 0?
Thank you very much!
import torch.nn as nn
from torch.distributed.tensor.parallel import (
PairwiseParallel,
parallelize_module,
)
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh #, DTensor, Shard, Replicate, distribute_tensor
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
)
ITER_TIME = 2
class ToyModel(nn.Module):
def __init__(self, in_channels, hidden_channels):
super(ToyModel, self).__init__()
self.dummy_param = nn.Parameter(torch.empty(0))
self.net1 = nn.Linear(in_channels, hidden_channels)
self.relu = nn.ReLU()
self.net2 = nn.Linear(hidden_channels, in_channels)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
@spawn_threads_and_init_comms
def demo_tp(world_size):
"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
rank = dist.get_rank()
print("Create a sharding plan based on the given world_size", rank)
mesh = torch.arange(world_size)
print(f"used mesh: {mesh}")
# create a sharding plan based on the given world_size.
device_mesh = DeviceMesh(
"cuda",
mesh,
)
in_dim = 1024
hidden_dim = 4 * in_dim
# create model and move it to GPU with id rank
model = ToyModel(in_dim, hidden_dim).to(rank)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"pytorch_total_params {pytorch_total_params}, pytorch_total_params {pytorch_total_params}")
# Create a optimizer for the parallelized module.
LR = 0.25
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
lparallel=True
if lparallel:
print("Parallelize the module based on the given Parallel Style", rank)
model = parallelize_module(model, device_mesh, PairwiseParallel())
print(model)
print(f"model of rank {rank} on {model.dummy_param.device}")
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
for i in range(ITER_TIME):
inp = torch.rand(10000, in_dim).to(rank)
output = model(inp)
#print(f"FWD Step: iter {i}", rank)
output.sum().backward()
#print(f"BWD Step: iter {i}", rank)
optimizer.step()
#print(f"Optimization Step: iter {i}", rank)
#print("Training finished", rank)
print(f'{rank}, max memory alloc: {torch.cuda.max_memory_allocated(device=rank)}')
WORLD_SIZE=4
demo_tp(WORLD_SIZE)