The true size of DTensor (Distributed Tensor in Tensor Parallel)

print(x, x.shape, x.abs().sum())

DTensor(
    local_tensor=AsyncCollectiveTensor(tensor([[[ 0.1075, -2.3521,  0.0279, -1.4321,  1.8674,  0.0181, -1.0666,
           0.2343, -0.8745,  0.2822,  0.0908,  0.0743,  1.0759,  0.7753,
           1.4046, -1.2099, -0.4284,  0.3137, -0.4697, -0.8437,  1.3829,
          -0.5523,  1.3563,  0.7076, -0.4797,  0.8596,  0.0976, -0.7871,
           1.0103,  1.4723,  0.1845,  0.2991]]], device='cuda:1')), 
    device_mesh=DeviceMesh([0, 1], 
    mesh_dim_names=('tp',)), 
    placements=(Shard(dim=1),)
) 
torch.Size([1, 2, 32])  # shape is still the original (global) one from before sharding
DTensor(
    local_tensor=24.138418197631836, 
    device_mesh=DeviceMesh([0, 1], 
    mesh_dim_names=('tp',)), 
    placements=(_Partial(reduce_op=RedOpType.SUM),)
)

I have a question about the reported size of the DTensor x in the example above. It is sharded along dim 1 (the row axis), from [1, 2, 32] down to [1, 1, 32] per rank, so the actual local size should be [1, 1, 32], which can be verified by printing it. But x.shape still reports [1, 2, 32]. Can anyone explain why? Thanks in advance.

The true size of a DTensor in memory depends on several factors, including its data type, the number of elements, and any additional metadata PyTorch stores. If you need precise memory usage, consider using the element_size() method and multiplying it by the number of elements in the tensor.

import torch

# element_size() is the size of one element in bytes; nelement() is the
# total element count, so their product is the raw data size in bytes.
dtensor = torch.tensor([1.0, 2.0, 3.0])
size_in_bytes = dtensor.element_size() * dtensor.nelement()
print(f"Size of DTensor: {size_in_bytes} bytes")  # 3 float32 elements * 4 bytes = 12 bytes

Plus, tools like torch.cuda.memory_allocated() can help measure memory usage if you’re working with CUDA tensors.
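To make the distinction concrete for the shapes in this thread, here is a small sketch using plain CPU tensors as stand-ins for the global tensor and one rank's shard (this is an assumption for illustration; no DTensor or tensor-parallel setup is involved):

```python
import torch

# The global shape reported by x.shape, and the per-rank shape after
# sharding dim 1 across a 2-rank mesh, as in the thread above.
global_shape = (1, 2, 32)
local_shape = (1, 1, 32)

g = torch.zeros(global_shape, dtype=torch.float32)
l = torch.zeros(local_shape, dtype=torch.float32)

global_bytes = g.element_size() * g.nelement()  # 4 bytes * 64 elements
local_bytes = l.element_size() * l.nelement()   # 4 bytes * 32 elements

print(global_bytes, local_bytes)  # 256 128
```

So for float32, each rank actually holds 128 bytes of data even though the logical (global) tensor is 256 bytes.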

According to the source code of DTensor, we can get its true local size by pulling the local_tensor out of the wrapper.

print(x)

DTensor(
    local_tensor=AsyncCollectiveTensor(tensor([[[ 0.1075, -2.3521,  0.0279, -1.4321,  1.8674,  0.0181, -1.0666,
           0.2343, -0.8745,  0.2822,  0.0908,  0.0743,  1.0759,  0.7753,
           1.4046, -1.2099, -0.4284,  0.3137, -0.4697, -0.8437,  1.3829,
          -0.5523,  1.3563,  0.7076, -0.4797,  0.8596,  0.0976, -0.7871,
           1.0103,  1.4723,  0.1845,  0.2991]]], device='cuda:1')), 
    device_mesh=DeviceMesh([0, 1], 
    mesh_dim_names=('tp',)), 
    placements=(Shard(dim=1),)
) 

print(x.shape)
torch.Size([1, 2, 32])  # this outputs the original (global) shape from before sharding

print(x._local_tensor.shape)
torch.Size([1, 1, 32])  # this gives the expected local (sharded) shape
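Putting the two answers together, one could measure the bytes a rank actually holds by unwrapping the local tensor first. `local_nbytes` below is a hypothetical helper (not part of the DTensor API), and `_local_tensor` is the private attribute shown above; a plain tensor is used here as a stand-in for one rank's shard:

```python
import torch


def local_nbytes(t):
    """Bytes actually stored on this rank.

    Hypothetical helper: if `t` wraps a local shard (as a DTensor does
    via its private `_local_tensor` attribute), measure the shard;
    otherwise measure the tensor itself.
    """
    local = getattr(t, "_local_tensor", t)
    return local.element_size() * local.nelement()


# A plain float32 tensor with one rank's sharded shape from the thread:
shard = torch.zeros(1, 1, 32)
print(local_nbytes(shard))  # 128
```

On a real DTensor, `x.element_size() * x.nelement()` would count the global [1, 2, 32] elements, while this helper counts only the [1, 1, 32] shard held locally.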