Distributed Collectives

Hi! I am new to PyTorch and am working with the distributed package.

I am trying to implement some parallelism functionality but am running into some issues:

  • I want to partition/shard a tensor. I know there are some things hardcoded in here that I also need to figure out, but for some reason dist.scatter is producing [[1], [3]] and [[2], [3]] instead of the expected [[1], [3]] and [[2], [4]]:
import torch
import torch.distributed as dist

# assumes local_rank and m_view_2 (a process group over ranks 0 and 1) were created earlier
if local_rank == 0:
    weight = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda(local_rank)
else:
    weight = None

world_size = dist.get_world_size(m_view_2)
chunks = None
if local_rank == 0:
    chunks = list(torch.chunk(weight, world_size, dim=1))

part_tensor = torch.empty((2, 1), dtype=torch.float32).cuda(local_rank)  # need to figure out how to dynamically get shape (see the sketch below)

dist.scatter(part_tensor, scatter_list=chunks, src=0, group=m_view_2, async_op=False)
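
One idea for the hardcoded receive shape (just a sketch, assuming the tensor splits evenly across the group): broadcast the full shape from rank 0 with dist.broadcast_object_list and derive each rank's chunk size from it, at the cost of one extra small communication.

shape_holder = [None]
if local_rank == 0:
    shape_holder[0] = tuple(weight.shape)
# NCCL object collectives use the current CUDA device, so torch.cuda.set_device(local_rank)
# should have been called beforehand
dist.broadcast_object_list(shape_holder, src=0, group=m_view_2)

rows, cols = shape_holder[0]
# each rank receives a (rows, cols // world_size) slice when chunking along dim=1
part_tensor = torch.empty((rows, cols // world_size), dtype=torch.float32).cuda(local_rank)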

On a more general note, what I am trying to do is take in a sequence of operations and execute them. For example:

  • partition input tensor across devices 0 and 1
  • replicate weights across devices 0 and 1
  • multiply input tensor and weight
  • reduce the results onto device 0
  • replicate the result across devices 0 and 1

I know there are libraries out there that would make what I am outlining relatively simple, but my goal here is to be explicit about each individual step rather than relying on something like an all_reduce. Does anyone have thoughts on how I should approach this, or considerations I should keep in mind (async vs. sync operations is something I likely have to be mindful of)?

If you want to be more explicit, you can call reduce instead of all_reduce for "reduce the results onto device 0". Is this what you want to achieve?

Yes! For each of those operations, I believe the corresponding collectives would be (rough sketch wiring them together after the list):

partitioning/sharding (ex. [[2, 3], [4, 5]] → [[2, 3]] and [[4, 5]]): scatter
replicating (ex. [[2, 3]] → [[2, 3]] and [[2, 3]]): broadcast
combining (ex. [[2, 3]] and [[4, 5]] → [[2, 3], [4, 5]]): gather
reduce (ex. [[2, 3]] and [[4, 5]] → [[6, 8]]): reduce
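
For what it's worth, here is a rough sketch of how that sequence could map onto explicit collective calls. It is only a sketch: it reuses local_rank, world_size, and the two-rank group m_view_2 from the snippet above, the tensor names are made up, and whether the combine should really be a reduce or a gather depends on which dimension you shard.

import torch
import torch.distributed as dist

# 1) partition the input across ranks 0 and 1 (scatter from rank 0)
x_chunks = None
if local_rank == 0:
    x = torch.tensor([[1., 2.], [3., 4.]]).cuda(local_rank)
    x_chunks = [c.contiguous() for c in torch.chunk(x, world_size, dim=1)]
x_shard = torch.empty((2, 1), dtype=torch.float32).cuda(local_rank)
dist.scatter(x_shard, scatter_list=x_chunks, src=0, group=m_view_2)

# 2) replicate the weight across ranks 0 and 1 (broadcast from rank 0)
if local_rank == 0:
    w = torch.tensor([[1., 2.], [3., 4.]]).cuda(local_rank)
else:
    w = torch.empty((2, 2), dtype=torch.float32).cuda(local_rank)
dist.broadcast(w, src=0, group=m_view_2)

# 3) multiply locally on each rank's shard
partial = w @ x_shard  # (2, 2) @ (2, 1) -> (2, 1), same shape on every rank

# 4) reduce the partial results onto rank 0 (sum); note that with this
#    column sharding the columns of w @ x live on different ranks, so a
#    gather/concat would actually be the matching combine here, not a sum
dist.reduce(partial, dst=0, op=dist.ReduceOp.SUM, group=m_view_2)

# 5) replicate the result back across ranks 0 and 1
dist.broadcast(partial, src=0, group=m_view_2)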

I'm still unsure why the example from my original post isn't creating the right output, and whether these collective operations will cause issues with autograd/gradient tracking.

edit: I was able to resolve the issue in the original post by adding:

if local_rank == 0:
    chunks = list(torch.chunk(weight, world_size, dim=1))

    # torch.chunk along dim=1 returns non-contiguous views of the original
    # tensor, and the scatter input needs to be contiguous
    for i in range(len(chunks)):
        chunks[i] = chunks[i].contiguous()

I'm not entirely sure what the implications of this are for gradient tracking, though, since .contiguous() seems to make a new tensor.

I got confused: why do you use scatter? Is it because your tensor originally lives only on rank 0, so you need to call scatter to distribute it? If so, I think your understanding is correct.

That's because the input of a collective needs to be contiguous, while when you do split or chunk, each chunk might still keep the stride of the original tensor.
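
A quick local illustration of that (no process group needed, just a sketch):

import torch

weight = torch.tensor([[1., 2.], [3., 4.]])
chunks = torch.chunk(weight, 2, dim=1)

# each chunk is a view into the parent tensor and keeps its stride
print(chunks[0].is_contiguous())        # False
print(chunks[0].stride())               # (2, 1), same as weight.stride()
print(chunks[0].contiguous().stride())  # (1, 1): .contiguous() copies into fresh memory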