How to preserve backward grad_fn after distributed operations

I am trying to implement model parallelism in a distributed cluster setting.

Let’s say I have a tensor in each process, call it tensor, and a number of operations have been performed on it (in each process independently), so it has a .grad_fn attached. Now I want to perform an all_gather so that I get a list [tensor_1, tensor_2, ..., tensor_n], which I can then concatenate using torch.cat. All the tensors in the list will lose the grad_fn property. My expectation is that process i keeps the grad_fn for tensor_i in the list. It’s OK if all the others are lost. I want to be able to call backward() after torch.cat in each process i and have it flow through tensor_i. How can I achieve that? Any help is appreciated!

EDIT: I think I can just do tensor_list[dist.get_rank()] = tensor after the all_gather operation but I am not sure if there is a better way. Help?
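
A minimal sketch of the workaround from the EDIT, assuming torch.distributed is already initialized and every rank passes a tensor of the same shape. The helper name all_gather_keep_local is hypothetical. Only the local slot keeps its grad_fn; the slots received from other ranks stay plain tensors, so each rank backpropagates only through its own contribution.

```python
import torch
import torch.distributed as dist


def all_gather_keep_local(tensor):
    """Gather `tensor` from all ranks, keeping autograd history for the local slot only.

    Hypothetical helper; assumes dist.init_process_group() has already been called.
    """
    world_size = dist.get_world_size()
    # Plain buffers without autograd history: dist.all_gather fills them in place.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor.detach())
    # Put the original tensor (with its grad_fn) back into this rank's slot.
    gathered[dist.get_rank()] = tensor
    return gathered
```

Calling torch.cat on the returned list and then backward() should reach the local graph through that one slot. Gradients that other ranks compute for this rank's tensor are not communicated back in this sketch.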

Would it be possible to manually assign the grad_fn back to tensor_i?

I don’t think it’s a good idea to retain gradient functions on output tensors of collective functions. If anything, it would give an expectation of this working well out of the box, which is not the case. I think a better solution would be to stitch things together with torch.autograd.grad yourself, before and after the collectives.
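
One way the stitching suggested here might look, as a rough sketch rather than a recommended implementation. It assumes an initialized process group, the same tensor shape on every rank, a loss that depends on every gathered tensor, and a loss function without trainable parameters of its own. The names backward_through_all_gather and loss_fn are hypothetical. Summing the gradients with all_reduce corresponds to differentiating the sum of the per-rank losses.

```python
import torch
import torch.distributed as dist


def backward_through_all_gather(local, loss_fn):
    """Backpropagate a loss computed on gathered tensors into `local`'s graph.

    Hypothetical helper; see the assumptions listed in the text above.
    """
    world_size, rank = dist.get_world_size(), dist.get_rank()

    # 1) Run the collective on detached data; autograd does not see this step.
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local.detach())

    # 2) Treat the gathered copies as fresh leaves and differentiate the
    #    loss with respect to them.
    leaves = [g.requires_grad_() for g in gathered]
    loss = loss_fn(leaves)
    grads = torch.autograd.grad(loss, leaves)

    # 3) Sum every rank's gradient for every slot, then keep this rank's slot.
    stacked = torch.stack(grads)
    dist.all_reduce(stacked)  # the default reduction op is SUM

    # 4) Continue backpropagation through the local graph.
    local.backward(stacked[rank])
    return loss.detach()
```

After the call, parameters upstream of local should hold gradients that include the contributions from all ranks.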

Did you find a solution? Do you want to compute tensor_i in a different process each, but combine them across processes so that the loss depends on all the tensor_i?

I’ve built this package that does this automatically now: https://github.com/ag14774/diffdist. So this question can be marked as solved

That’s really cool! Thanks for sharing!
I am not sure how the package works, though, or whether it applies to a problem like this:

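for iteration, (data0, data1) in enumerate(data_loader, start_iter):
    tensor = model(data0)
    synchronize()
    # dist.all_gather fills a preallocated list in place; the results carry no grad_fn
    tensors = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors, tensor)
    loss = model(data1, tensors)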

So in each process a different data0 generates a tensor, and the gathered tensors are used for further training. Since all_gather does not preserve grad_fn, can you give me some advice on how to solve this?
Thanks a lot.

Yes, the package can do that. However, tensor needs to have the same shape in all processes. Then you can do something like:

for iteration, (data0, data1) in enumerate(data_loader, start_iter):
    tensor = model(data0)
    synchronize()  # You probably do not need this since all_gather will force a sync
    gather_list = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    gather_list = diffdist.functional.all_gather(gather_list, tensor)
    loss = model(data1, gather_list)

Keep in mind though that all_gather is not very fast because its backprop involves running dist.reduce multiple times. When pytorch adds support for reduce_scatter, I will update the package to speed up the backprop.
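
To make the cost concrete, here is a rough sketch of a differentiable all_gather whose backward pass calls dist.reduce once per rank. This is not diffdist's actual code, only an illustration of the pattern described above, and the class name AllGatherWithGrad is hypothetical. It assumes an initialized process group and equal tensor shapes on all ranks.

```python
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """all_gather whose outputs keep a grad_fn (illustrative sketch only)."""

    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor.contiguous())
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        rank = dist.get_rank()
        grad_input = None
        # Slot `dst` came from rank `dst`, so its gradients from all ranks
        # are summed onto rank `dst`: one reduce per slot.
        for dst, grad in enumerate(grad_outputs):
            buf = grad.contiguous().clone()
            dist.reduce(buf, dst=dst)
            if dst == rank:
                grad_input = buf
        return grad_input
```

Usage would be outputs = AllGatherWithGrad.apply(tensor). Every rank has to run the backward pass, because each dist.reduce is itself a collective call.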

Thank you so much for your help.
I tried the code, but the tensors in gather_list after diffdist.functional.all_gather(gather_list, tensor) still don’t have a grad_fn.

I noticed an attribute self.next_backprop in your code. Do I need to set it? Sorry to bother you again.

Apologies, the line should be

gather_list = diffdist.functional.all_gather(gather_list, tensor)

If you get any errors, try setting inplace=False. There is no need to use next_backprop.

Thank you again. One final question: I wrote a simple example to see how grad_fn works:

    # in each process:
    a = torch.tensor([1.0, 3.0], requires_grad=True).cuda()
    b = a + 2 * dist.get_rank()
    # gather
    bs = [torch.empty_like(b) for i in range(dist.get_world_size())]
    bs = diffdist.functional.all_gather(bs, b)
    # loss backward
    loss = (torch.cat(bs) * torch.cat(bs)).mean()
    loss.backward()
    print(a.grad)

I think a should have a gradient, but currently it is None. I am a little bit lost.

You are right, it seems to work on CPU but not on CUDA for some reason. I will investigate a bit more. Feel free to open a pull request if you find the problem.

I found the problem. The package is working fine. The problem is that when you set requires_grad=True, you set it on the CPU version of a. Then you called cuda(), which created another node in the graph, so the name a now refers to a non-leaf GPU tensor. The gradient passes through that GPU tensor and is accumulated in the CPU tensor, since that is the leaf with requires_grad set to True, which is why a.grad on the GPU tensor is None. What you should do is create the tensor directly on the GPU: torch.tensor([1.0, 3.0], requires_grad=True, device='cuda'). In a realistic scenario with normal training this won’t be a problem.
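
The leaf versus non-leaf difference can be checked without any distributed code. The snippet below is a small sketch: it falls back to the CPU when no GPU is present and uses .to(device, copy=True) as a stand-in for .cuda() so that a copy node is created in both cases. The variable names are made up for the illustration.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pattern from the question: requires_grad is set on the CPU tensor,
# then a copy is made. The copy is a non-leaf with its own grad_fn.
cpu_leaf = torch.tensor([1.0, 3.0], requires_grad=True)
moved = cpu_leaf.to(device, copy=True)  # stand-in for .cuda()
(moved * moved).sum().backward()
print(moved.is_leaf)   # the copy is not a leaf, so its .grad is not populated
print(cpu_leaf.grad)   # the gradient is accumulated on the original CPU leaf

# Suggested fix: create the leaf directly on the target device.
fixed = torch.tensor([1.0, 3.0], requires_grad=True, device=device)
(fixed * fixed).sum().backward()
print(fixed.is_leaf)
print(fixed.grad)
```

In the first half the gradient still exists, it just lives on the tensor that no variable in the original example refers to any more.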

Sorry for my late reply.

I tried your advice and then applied it to my own model, and it works! Thank you for your help. Actually, I don’t know how you implemented your model parallelism; here I use DistributedDataParallel in PyTorch to distribute the model across the GPUs of one node. So based on my experiment, I think your work may also solve the problem of gathering tensors with their grad_fn across distributed GPUs, as in the thread Will "dist.all_gather" break the auto gradient graph?. Thank you again.

Glad it works!

Yes, it seems that diffdist can handle that case. Of course, different processes will have different computational graphs, but diffdist inserts nodes into each graph that cause the processes to sync and communicate with each other. For example, doing a Send operation will cause a Recv to be called during backward in order to receive the gradient.