How to preserve backward grad_fn after distributed operations

Sorry for my late reply.

I tried your advice and applied it to my own model, and it works! Thank you for your help. I don't actually know how you implement your model parallelism; here I use DistributedDataParallel in PyTorch to distribute the model across the GPUs of a single node. Based on my experiment, I think your work might also solve the problem of gathering grad_fn across distributed GPUs, like the one discussed in Will "dist.all_gather" break the auto gradient graph?. Thank you again.
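For anyone who lands here from that linked thread: a plain `dist.all_gather` writes into pre-allocated output tensors that have no `grad_fn`, so the autograd graph is cut at that point. Below is a minimal sketch (not the original poster's code) of one common workaround: a custom `autograd.Function` whose forward gathers tensors from every rank and whose backward sums the incoming gradients across ranks and returns the slice belonging to the local rank. The class name `AllGatherWithGrad` is just an illustrative label, and it assumes the default process group has already been initialized.

```python
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """all_gather that keeps the autograd graph connected."""

    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        # Returning a tuple makes autograd hand backward() one grad per rank.
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Stack per-rank gradients, sum the contributions from every process,
        # then keep only the slice that corresponds to this rank's input.
        grads = torch.stack(list(grad_outputs))
        dist.all_reduce(grads)
        return grads[dist.get_rank()]


def all_gather_with_grad(tensor):
    """Convenience wrapper returning a list of gathered, differentiable tensors."""
    return list(AllGatherWithGrad.apply(tensor))
```

If I remember correctly, recent PyTorch releases also ship a differentiable all_gather under `torch.distributed.nn`, which follows the same idea, so it may be worth checking whether your version already provides it before rolling your own.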