fully_shard with a 2D mesh (4, 1) still runs all-gather / reduce-scatter on the shard dimension

Description

When a 2D mesh of shape (4, 1) is passed to fully_shard, the shard dimension has size 1, which is semantically equivalent to pure replication. However, the current implementation still takes the HSDP path and issues all-gather / reduce-scatter on the shard dimension. The collectives degenerate to a group of size 1 but still go through copy-in/copy-out, event synchronization, and state transitions, which adds overhead.
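
To make the "degenerate" claim concrete, here is a minimal pure-Python sketch of the two collectives, simulated over lists rather than real process groups. The function names `all_gather` and `reduce_scatter` are illustrative stand-ins, not the PyTorch implementation; the point is only that with a group of size 1 both operations return the local data unchanged.

```python
from typing import List


def all_gather(shards: List[List[float]]) -> List[float]:
    """Simulated all-gather: concatenate every rank's shard in rank order."""
    gathered: List[float] = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


def reduce_scatter(grads: List[List[float]]) -> List[List[float]]:
    """Simulated reduce-scatter: sum elementwise, then give each rank one chunk."""
    world_size = len(grads)
    numel = len(grads[0])
    assert numel % world_size == 0, "sketch assumes an evenly divisible length"
    summed = [sum(g[i] for g in grads) for i in range(numel)]
    chunk = numel // world_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


# With a shard group of size 1, both collectives are identity operations.
local = [1.0, 2.0, 3.0, 4.0]
assert all_gather([local]) == local
assert reduce_scatter([local]) == [local]
```

Any work done around these calls (buffer copies, stream events) is therefore overhead with no data movement to justify it.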

Code context (why this happens)

  • 2D meshes are always treated as HSDP MeshInfo with shard_mesh_dim=1; there is no special case for shard_size == 1.

  • Hooks still run unshard/reshard (pre_forward → unshard → wait_for_unshard, post_forward → reshard).

  • Collectives are invoked using the shard process group; even with world size 1 they still go through copy-in/copy-out paths.
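
The kind of special-casing being asked about could look roughly like the sketch below. This is a hypothetical illustration, not existing PyTorch code: `classify_mesh` and the strategy strings are invented names, and the layout (replicate on dim 0, shard on dim 1) is assumed from the shard_mesh_dim=1 observation above.

```python
from typing import Tuple


def classify_mesh(shape: Tuple[int, ...]) -> str:
    """Hypothetical helper: choose a strategy label from a mesh shape.

    Assumes 2D meshes are laid out as (replicate, shard).
    """
    if len(shape) == 1:
        return "fsdp"
    if len(shape) == 2:
        _replicate_size, shard_size = shape
        # Degenerate case: nothing to shard, so shard collectives are no-ops.
        if shard_size == 1:
            return "replicate_only"
        return "hsdp"
    raise ValueError(f"expected a 1D or 2D mesh, got shape {shape}")


assert classify_mesh((4, 1)) == "replicate_only"
assert classify_mesh((2, 2)) == "hsdp"
assert classify_mesh((4,)) == "fsdp"
```

A real fast path would also have to decide how the hooks and state transitions behave in the "replicate_only" case, which is part of what the questions below are asking.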

Questions for maintainers

  1. Is running shard all-gather/reduce-scatter when the shard dimension size is 1 an intentional design choice?

  2. If yes, what is the rationale (e.g., keeping the FSDP/HSDP state machine uniform, simplifying the implementation, or future dynamic mesh compatibility)?

  3. Are there plans to add a degenerate fast path or a warning/doc note for shard_size == 1?