Implementing PCGrad in PyTorch Lightning


I would like to implement the paper [2001.06782] Gradient Surgery for Multi-Task Learning.
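In summary, I am working in a multi-task learning setting, so there is more than one loss function. I need to compute a separate gradient for each loss and then modify these gradients depending on their pairwise cosine similarity (in PCGrad, when two task gradients conflict, i.e. have negative cosine similarity, one is projected onto the normal plane of the other).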
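For reference, the projection step from the paper can be sketched on flattened gradient vectors. This is a minimal pure-Python sketch, assuming each task's gradient has already been flattened into a list of floats of the same length; `pcgrad` and `dot` are hypothetical helper names, not part of any library. A negative dot product is used as the conflict test, since it has the same sign as the cosine similarity.

```python
import random
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def pcgrad(grads: List[Vector], seed: int = 0) -> Vector:
    """Sketch of PCGrad: project away conflicting components, then sum.

    grads: one flattened gradient per task (all the same length).
    """
    rng = random.Random(seed)
    projected = []
    for i, g_i in enumerate(grads):
        g = list(g_i)  # copy, so the other tasks still see the original g_i
        others = [j for j in range(len(grads)) if j != i]
        rng.shuffle(others)  # the paper visits the other tasks in random order
        for j in others:
            g_j = grads[j]
            d = dot(g, g_j)
            if d < 0:  # conflict: remove the component of g along g_j
                scale = d / dot(g_j, g_j)
                g = [x - scale * y for x, y in zip(g, g_j)]
        projected.append(g)
    # Element-wise sum of the projected task gradients.
    return [sum(xs) for xs in zip(*projected)]


# Two conflicting task gradients (their dot product is -1).
print(pcgrad([[1.0, 0.0], [-1.0, 1.0]]))  # [0.5, 1.5]
```

Here the first gradient becomes `[0.5, 0.5]` (orthogonal to the second) and the second becomes `[0.0, 1.0]` (orthogonal to the first), so their sum is `[0.5, 1.5]`.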

In standard PyTorch I can do this with torch.autograd.grad; however, after computing and modifying these gradients, I would like to pass them to the optimizer without recomputing them. How can I do this in PyTorch Lightning? Can I call manual_backward for each loss within a single training_step, or can I pass the summed, already-computed gradients to the optimizer?
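One way to hand precomputed gradients to an optimizer in plain PyTorch is to assign them to each parameter's `.grad` and then call `optimizer.step()`, so no second backward pass is needed. The sketch below assumes a single model whose parameters are shared by all task losses; the helper name `step_with_custom_grads` and its `combine` argument are hypothetical, and the default combination is a plain sum standing in for the PCGrad projection.

```python
import torch
from torch import nn


def step_with_custom_grads(model, opt, losses, combine=None):
    """Compute one gradient per loss, combine them, and hand the result to `opt`.

    `combine` (hypothetical) maps a list of flattened per-task gradients to one
    flat vector. It defaults to a plain sum; a PCGrad projection could go here.
    """
    params = [p for p in model.parameters() if p.requires_grad]

    # One gradient per loss; retain_graph=True because the losses share a graph.
    task_grads = []
    for loss in losses:
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        task_grads.append(torch.cat([g.reshape(-1) for g in grads]))

    if combine is None:
        combined = torch.stack(task_grads).sum(dim=0)
    else:
        combined = combine(task_grads)

    # Write the combined vector back into each parameter's .grad, slice by
    # slice, so the optimizer consumes it directly.
    opt.zero_grad()
    offset = 0
    for p in params:
        n = p.numel()
        p.grad = combined[offset:offset + n].view_as(p).clone()
        offset += n
    opt.step()
    return combined


# Usage on a toy two-task problem (random data, for illustration only).
torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)
out = model(x)
losses = [
    nn.functional.mse_loss(out, torch.randn(8, 1)),
    nn.functional.l1_loss(out, torch.randn(8, 1)),
]
combined = step_with_custom_grads(model, opt, losses)
```

Inside a LightningModule, a similar body could run in `training_step` after setting `self.automatic_optimization = False` in `__init__` and fetching the optimizer with `self.optimizers()`, with `self` in place of `model`. Because this writes `.grad` directly instead of going through `self.manual_backward`, anything Lightning hooks into the backward pass (for example mixed-precision loss scaling or distributed gradient synchronization) may be skipped; treat that as an untested assumption.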

Thanks for the help