I am trying to implement model parallelism in a distributed cluster setting. Let's say each process holds a tensor `tensor` on which a number of operations have been performed (independently in each process), so the tensor has a `.grad_fn` attached to it. Now I want to perform an `all_gather` to build a list `[tensor_1, tensor_2, ..., tensor_n]` and then concatenate those tensors with `torch.cat`. The problem is that all the tensors in the gathered list lose their `grad_fn`. What I want is for process `i` to keep the autograd history of its own `tensor_i` in the list; it's OK if the history of all the others is lost. In other words, I want to be able to call `backward()` after the `torch.cat` in each process `i`, with gradients flowing through `tensor_i`. How can I achieve that? Any help is appreciated!
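Here is a minimal repro of what I mean (two processes with the `gloo` backend, just for illustration; the address/port and script layout are placeholders):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    x = torch.randn(4, requires_grad=True)
    tensor = x * 2  # tensor.grad_fn is set (MulBackward0)

    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)

    # every gathered entry is detached from the graph: grad_fn is None,
    # so backward() through torch.cat(tensor_list) stops here
    print(rank, [t.grad_fn for t in tensor_list])
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```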
EDIT: I think I can just do `tensor_list[dist.get_rank()] = tensor` after the `all_gather` operation, but I am not sure if there is a better way. Help?
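For reference, this is the workaround I have in mind, as a sketch (the helper name `gather_cat_with_local_grad` is just mine; the setup is the same as in the repro above):

```python
import torch
import torch.distributed as dist

def gather_cat_with_local_grad(tensor):
    """all_gather, then re-insert the local tensor so that backward()
    flows through this rank's slice of the concatenated result."""
    world_size = dist.get_world_size()
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    # the gathered copies are detached; put the autograd-tracked local
    # tensor back into its own slot so grad_fn survives for this rank
    tensor_list[dist.get_rank()] = tensor
    return torch.cat(tensor_list)

# usage: out = gather_cat_with_local_grad(tensor); out.sum().backward()
# gradients then flow only into the inputs of the local tensor_i.
```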