Distributed: scatter list of tensors of different sizes

Say I have two processes running and two tensors, t0 and t1, with sizes 13 and 6, respectively, both stored on the process with rank zero. How can I scatter the tensor list [t0, t1] to ranks zero and one? When I try using dist.scatter I get the error

    RuntimeError: ProcessGroupGloo::scatter: invalid tensor size at index 1 (expected (13), got (6))

so it seems that dist.scatter expects the output tensor size to be the same across all processes.
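
Roughly, the call looks like this (a simplified sketch, not my exact script; the sizes and the world size of two match the description above):

    import torch
    import torch.distributed as dist

    dist.init_process_group("gloo")
    rank = dist.get_rank()

    # rank 0 holds both tensors; the scatter list entries have different sizes
    scatter_list = [torch.rand(13), torch.rand(6)] if rank == 0 else None  # t0, t1

    # output buffer sized for this rank's tensor (13 on rank 0, 6 on rank 1)
    out = torch.empty(13 if rank == 0 else 6)
    dist.scatter(out, scatter_list, src=0)  # gloo rejects the mismatched sizes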

Since t0 and t1 are stored on rank 0, you wouldn’t need to scatter t0 to itself, right? In this case, I would just send t1 from rank 0 and recv it on rank 1.
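
Roughly something like this (a sketch assuming just the two ranks and the sizes from your post):

    import torch
    import torch.distributed as dist

    dist.init_process_group("gloo")
    rank = dist.get_rank()

    if rank == 0:
        t0 = torch.rand(13)   # stays on rank 0, no communication needed
        t1 = torch.rand(6)
        dist.send(t1, dst=1)  # hand t1 over to rank 1
    elif rank == 1:
        t1 = torch.empty(6)   # receive buffer; the size 6 is known in advance here
        dist.recv(t1, src=0)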

Regarding the documentation, I opened an issue to track making this clearer for scatter(): Update dist.scatter() documentation to provide clearer arg description · Issue #84566 · pytorch/pytorch · GitHub

You’re right, but in reality I have more than two processes; I simplified it to two for the sake of this post. Thanks for opening the issue.

Scatter expects all tensors to be of the same size. This can be addressed by padding and some additional comms.

The trick is to broadcast the maximum tensor size, scatter each rank's actual size, pad every tensor up to the maximum, and trim the result after the scatter. Something like the following:

    import torch
    import torch.distributed as dist
    from torch.nn.functional import pad

    # assumes the process group is already initialized and the world size is 4,
    # one tensor per rank in the scatter list
    is_rank0 = dist.get_rank() == 0
    tensors = [torch.rand(2), torch.rand(3), torch.rand(4), torch.rand(5)] if is_rank0 else None

    # find the largest tensor size; it becomes the common scatter size
    max_len = max(t.numel() for t in tensors) if is_rank0 else 0
    print(f"{dist.get_rank()} local max len {max_len}")

    # broadcast the receive (padded) size to all ranks
    tensor_len = torch.tensor([max_len], dtype=torch.int64)
    dist.broadcast(tensor_len, src=0)
    print(f"{dist.get_rank()} max len {tensor_len.item()}")

    # scatter each rank's true (unpadded) size
    tensors_len = [torch.tensor([t.numel()], dtype=torch.int64) for t in tensors] if is_rank0 else None
    my_len = torch.zeros(1, dtype=torch.int64)
    dist.scatter(my_len, tensors_len, src=0)
    print(f"{dist.get_rank()} my len is {my_len.item()}")

    # zero-pad the input tensors to the common size
    if tensors is not None:
        tensors = [pad(t, (0, max_len - t.numel()), "constant", 0) for t in tensors]

    # scatter the padded tensors, then trim each rank's result to its true size
    receive_tensor = torch.empty(tensor_len.item())
    dist.scatter(receive_tensor, tensors, src=0)
    receive_tensor = receive_tensor[: my_len.item()]
    print(f"{dist.get_rank()} -> final len: {receive_tensor.numel()}")

This usage pattern is common enough that PyTorch should directly support it. Filed a feature request for it: Uneven and/or Dynamically sized collectives · Issue #84593 · pytorch/pytorch · GitHub