I’m not very familiar with Lightning, but the PyTorch distributed package doesn’t currently have a framework built exactly for this paradigm. Distributed Data Parallel parallelizes the forward pass and computes each loss locally on each node, but that is for a replicated model, not the multi-task model I believe you are asking about. One way to accomplish this yourself would be to use the Distributed RPC Framework (Distributed RPC Framework — PyTorch master documentation) and handle the parallelization and tensor communication on your own via remote procedure calls. A rough sketch of what that could look like is below.
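
This is only a minimal sketch of the RPC approach, not a complete recipe: it assumes one master process that owns a shared trunk and a few task workers that each own a task-specific head and compute their own loss. The names (`TaskHead`, `remote_loss`, `run_master`), the world size, and the tensor shapes are all illustrative.

```python
import os
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


class TaskHead(nn.Module):
    """Task-specific head living on a remote worker (illustrative)."""
    def __init__(self, in_features: int):
        super().__init__()
        self.fc = nn.Linear(in_features, 1)

    def compute_loss(self, features: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return nn.functional.mse_loss(self.fc(features), target)


# Module-level instance so plain functions can be invoked over RPC by name.
_head = None


def init_head(in_features: int):
    global _head
    _head = TaskHead(in_features)


def remote_loss(features: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Runs on the worker that owns the head; returns that task's loss.
    return _head.compute_loss(features, target)


def run_master(world_size: int, in_features: int = 16):
    trunk = nn.Linear(32, in_features)      # shared trunk held by the master
    x = torch.randn(8, 32)
    features = trunk(x)

    # Ask each worker to build its head, then launch the loss computations in parallel.
    futures = []
    for rank in range(1, world_size):
        worker = f"worker{rank}"
        rpc.rpc_sync(worker, init_head, args=(in_features,))
        target = torch.randn(8, 1)          # dummy per-task target
        futures.append(rpc.rpc_async(worker, remote_loss,
                                     args=(features.detach(), target)))

    losses = [fut.wait() for fut in futures]
    print("per-task losses:", [l.item() for l in losses])


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    name = "master" if rank == 0 else f"worker{rank}"
    rpc.init_rpc(name, rank=rank, world_size=world_size)
    if rank == 0:
        run_master(world_size)
    rpc.shutdown()  # blocks until all outstanding RPCs finish


if __name__ == "__main__":
    world_size = 3  # 1 master + 2 task workers, just as an example
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```

Note that the sketch detaches the features before sending them, so no gradients flow back into the shared trunk; if you want end-to-end training across processes, you’d combine RPC with distributed autograd and a distributed optimizer, which the same Distributed RPC Framework page covers.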