According to the documentation, torch.distributed.tensor.parallel.SequenceParallel is supposed to shard on the sequence dimension, i.e. [B, T, C] -> [B, T // _world_size, C], but it seems to be tiling instead, i.e. [B, T, C] -> [B, T * _world_size, C]. What am I missing?
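
Here is a minimal sketch of the kind of setup I mean (the LayerNorm module, the shapes, the script name, and the 2-GPU torchrun launch are just illustrative):

```python
# Minimal repro sketch (module choice, shapes, and launch are illustrative).
# Run with: torchrun --nproc_per_node=2 repro.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import SequenceParallel, parallelize_module

_world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (_world_size,))

B, T, C = 4, 8, 16
norm = parallelize_module(nn.LayerNorm(C).cuda(), mesh, SequenceParallel())

# Same full [B, T, C] tensor on every rank.
x = torch.randn(B, T, C, device="cuda")
out = norm(x)

# I expected a DTensor with global shape [B, T, C] and local shape
# [B, T // _world_size, C] per rank, but the global shape comes out as
# [B, T * _world_size, C] while the local shape stays [B, T, C].
print(out.shape, out.to_local().shape)
```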