I have a question about the shape of a DTensor. In the example above, x is sharded along the row axis (dim 1), so each rank's shard goes from [1, 2, 32] to [1, 1, 32]. The actual local size is therefore [1, 1, 32], which can be verified by printing the tensor. But x.shape still reports [1, 2, 32]. Can I know why? Thanks in advance.
x.shape (and x.size()) intentionally returns the global, logical shape of the DTensor, not the shape of the shard held by the current rank. A DTensor presents one logical tensor distributed across a device mesh, so shape-dependent code behaves exactly as it would on an ordinary tensor. To see the shard actually stored on the current rank, call x.to_local() and inspect its shape, which in your example will be [1, 1, 32].