I am trying to implement model parallelism in a distributed cluster setting.

Let’s say I have a tensor `tensor` in each process, and a number of operations have been performed on it (in each process independently), so the tensor has a `.grad_fn` attached to it. Now I want to perform an `all_gather` so that I create a list `[tensor_1, tensor_2, ..., tensor_n]`, and then concatenate all those tensors using `torch.cat`. All the tensors in the list will lose their `grad_fn` property. My expectation is that process i maintains the `grad_fn` for `tensor_i` in the list; it’s ok if all the others are lost. I want to be able to call `backward()` after `torch.cat` in each process i, through `tensor_i`. How can I achieve that? Any help is appreciated!

EDIT: I think I can just do `tensor_list[dist.get_rank()] = tensor` after the `all_gather` operation, but I am not sure if there is a better way. Help?