Hi,
I just switched to PyTorch and have a question about a multi-task setting:
task1_preds, task2_preds = self.model(input)
task1_loss = self.crit(task1_preds, task1_labels)
task2_loss = self.crit(task2_preds, task2_labels)
I want to get the gradient of a tensor A with respect to each of these two losses separately, i.e. d(task1_loss)/dA and d(task2_loss)/dA:
task1_loss.backward(retain_graph=True)
A_task1_grad = A.grad.clone()  # clone: A.grad will be overwritten/accumulated below
A.grad = None  # reset, otherwise the next backward() accumulates into the same .grad
task2_loss.backward()  # retain_graph no longer needed on the last backward
A_task2_grad = A.grad
Just wondering if this should work. Any ideas? Thanks.
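For what it's worth, here is a minimal self-contained sketch of the same idea using torch.autograd.grad, which returns the gradients directly instead of accumulating into .grad (the toy losses below just stand in for the model outputs; they are not from the original post):

```python
import torch

# Toy stand-in for the multi-task setup: A is the tensor we want gradients for.
A = torch.randn(3, requires_grad=True)
task1_loss = (A * 2).sum()   # d(task1_loss)/dA = 2 everywhere
task2_loss = (A ** 2).sum()  # d(task2_loss)/dA = 2 * A

# torch.autograd.grad computes gradients without touching A.grad,
# so there is nothing to clone or zero between the two calls.
# retain_graph=True keeps the graph alive for the second call.
(A_task1_grad,) = torch.autograd.grad(task1_loss, A, retain_graph=True)
(A_task2_grad,) = torch.autograd.grad(task2_loss, A)

print(A_task1_grad)
print(A_task2_grad)
```

This avoids the accumulation pitfall entirely, since .grad is never populated.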