Question about tensor parallel (DTensor, parallelize_module)


I wanted to try the new tensor parallel framework. One thing that puzzles me is that it seems to consume more memory on rank 0 than without parallelize_module - where all ranks do the same thing. The code I use is adapted from the colab linked here: [RFC] PyTorch DistributedTensor · Issue #88838 · pytorch/pytorch · GitHub

with model = parallelize_module(…),I get for max memory (lparallel=True) :

0, max memory alloc: 980238336
1, max memory alloc: 117460992
2, max memory alloc: 117460992
3, max memory alloc: 117460992

and without (lparallel=False)

0, max memory alloc: 658613248
1, max memory alloc: 658613248
2, max memory alloc: 658613248
3, max memory alloc: 658613248

Is this expected? How can I reduce memory consumption on rank 0?

Thank you very much!

import torch.nn as nn
from torch.distributed.tensor.parallel import (

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh #, DTensor, Shard, Replicate, distribute_tensor

from torch.testing._internal.common_distributed import (


class ToyModel(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(ToyModel, self).__init__()
        self.dummy_param = nn.Parameter(torch.empty(0))
        self.net1 = nn.Linear(in_channels, hidden_channels)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(hidden_channels, in_channels)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_tp(world_size):
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native APIs.
    rank = dist.get_rank()
    print("Create a sharding plan based on the given world_size", rank)

    mesh = torch.arange(world_size)
    print(f"used mesh: {mesh}")
    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh(

    in_dim = 1024
    hidden_dim = 4 * in_dim

    # create model and move it to GPU with id rank
    model = ToyModel(in_dim, hidden_dim).to(rank)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"pytorch_total_params {pytorch_total_params}, pytorch_total_params {pytorch_total_params}")

    # Create a optimizer for the parallelized module.
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    if lparallel:
        print("Parallelize the module based on the given Parallel Style", rank)
        model = parallelize_module(model, device_mesh, PairwiseParallel())

    print(f"model of rank {rank} on {model.dummy_param.device}")

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for i in range(ITER_TIME):
        inp = torch.rand(10000, in_dim).to(rank)
        output = model(inp)
        #print(f"FWD Step: iter {i}", rank)
        #print(f"BWD Step: iter {i}", rank)
        #print(f"Optimization Step: iter {i}", rank)
    #print("Training finished", rank)
    print(f'{rank}, max memory alloc: {torch.cuda.max_memory_allocated(device=rank)}')