How to optimize multiple variables with respect to multiple losses?

I want different losses to have their gradients computed with respect to different variables, and those variables to then all step together.

Here’s a simple example demonstrating what I want:

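```python
import torch as T
x = T.randn(3, requires_grad=True)
y = T.randn(4, requires_grad=True)
z = T.randn(5, requires_grad=True)

x_opt = T.optim.Adadelta([x])
y_opt = T.optim.Adadelta([y])
z_opt = T.optim.Adadelta([z])

for i in range(n_iter):

  shared_computation = foobar(x, y, z)

  x_loss = f(x, y, z, shared_computation)
  y_loss = g(x, y, z, shared_computation)
  z_loss = h(x, y, z, shared_computation)

  # pseudocode: `backward_with_respect_to` does not exist in PyTorch;
  # it stands for "compute this loss's gradient for this variable only"
  x_loss.backward_with_respect_to(x)
  y_loss.backward_with_respect_to(y)
  z_loss.backward_with_respect_to(z)

  x_opt.step()
  y_opt.step()
  z_opt.step()
```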
My question is how to do that backward_with_respect_to part in PyTorch. I only want x's gradient to come from x_loss (and likewise y's from y_loss and z's from z_loss). Then I want all the optimizers to step together, based on the current values of x, y, and z.

I’ve written a function to do just this. The two key components are (1) passing retain_graph=True to every call to .backward() except the last one, and (2) saving the gradients after each call to .backward() and restoring them at the end, before calling .step().

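```python
def multi_step(losses, optms):
  # optimizers each take a step, with `optms[i]`'s variables being
  # optimized w.r.t. `losses[i]`.
  grads = [None]*len(losses)
  for i, (loss, optm) in enumerate(zip(losses, optms)):
    # keep the (shared) graph alive for every backward except the last
    retain_graph = i != (len(losses)-1)
    # clear this optimizer's grads so they hold only d(losses[i])/d(param)
    optm.zero_grad()
    loss.backward(retain_graph=retain_graph)
    # save a copy, since later backward calls accumulate into p.grad
    grads[i] = [
        [None if p.grad is None else p.grad.clone() for p in group['params']]
        for group in optm.param_groups
      ]
  # restore the saved grads, then let every optimizer step
  for optm, grad in zip(optms, grads):
    for p_group, g_group in zip(optm.param_groups, grad):
      for p, g in zip(p_group['params'], g_group):
        p.grad = g
    optm.step()
```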
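Point (1) can be checked in isolation. In the sketch below, the shared intermediate `shared = (x * 2).exp()` and the helper `two_backwards` are made up for illustration: two losses reuse `shared`, and the second `.backward()` only succeeds if the first one passed `retain_graph=True`.

```python
import torch

def two_backwards(retain_graph):
    """Run backward on two losses that share an intermediate tensor.

    Returns True if both calls succeed, False if the second one fails.
    """
    x = torch.randn(3, requires_grad=True)
    # Hypothetical shared computation; exp() saves its output for backward.
    shared = (x * 2).exp()
    loss_a = shared.sum()
    loss_b = (shared ** 2).sum()
    loss_a.backward(retain_graph=retain_graph)
    try:
        # Goes back through exp(), which needs its saved output again.
        loss_b.backward()
    except RuntimeError:
        # The saved tensors were freed by the first backward call.
        return False
    return True

print(two_backwards(retain_graph=True))   # True
print(two_backwards(retain_graph=False))  # False
```

The last loss can use the default `retain_graph=False`, since nothing needs the graph after it.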

In the example code stated in the question, multi_step would be used as follows:

```python
for i in range(n_iter):
  shared_computation = foobar(x, y, z)

  x_loss = f(x, y, z, shared_computation)
  y_loss = g(x, y, z, shared_computation)
  z_loss = h(x, y, z, shared_computation)

  multi_step([x_loss, y_loss, z_loss], [x_opt, y_opt, z_opt])
```
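To check the function on concrete data, the sketch below fills in hypothetical definitions of `foobar`, `f`, `g` and `h` (made up, not from the question), repeats `multi_step` (with the `zero_grad()`, `backward()` and `step()` calls it needs, plus a `None` check before copying gradients) so the block runs on its own, and compares the gradient each variable ends up with against `torch.autograd.grad` of its own loss. It only exercises this one setup, not every case.

```python
import torch

def multi_step(losses, optms):
    # optms[i]'s parameters are stepped using only the gradient of losses[i].
    grads = []
    for i, (loss, optm) in enumerate(zip(losses, optms)):
        optm.zero_grad()
        loss.backward(retain_graph=i != len(losses) - 1)
        grads.append([[None if p.grad is None else p.grad.clone()
                       for p in group["params"]]
                      for group in optm.param_groups])
    for optm, grad in zip(optms, grads):
        for group, g_group in zip(optm.param_groups, grad):
            for p, g in zip(group["params"], g_group):
                p.grad = g
        optm.step()

# Hypothetical stand-ins for the question's foobar, f, g and h.
def foobar(x, y, z):
    return torch.tanh(x.sum() + y.sum() + z.sum())

def f(x, y, z, s):
    return ((x - s) ** 2).sum() + (y ** 2).sum()

def g(x, y, z, s):
    return (y * s).sum() + x.abs().sum()

def h(x, y, z, s):
    return (z ** 2).sum() * s + y.sum()

torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)
y = torch.randn(4, requires_grad=True)
z = torch.randn(5, requires_grad=True)
x_opt = torch.optim.Adadelta([x])
y_opt = torch.optim.Adadelta([y])
z_opt = torch.optim.Adadelta([z])

s = foobar(x, y, z)
x_loss, y_loss, z_loss = f(x, y, z, s), g(x, y, z, s), h(x, y, z, s)

# Reference: the gradient of each loss w.r.t. its own variable only.
expected = [torch.autograd.grad(loss, var, retain_graph=True)[0]
            for loss, var in [(x_loss, x), (y_loss, y), (z_loss, z)]]
x_before = x.detach().clone()

multi_step([x_loss, y_loss, z_loss], [x_opt, y_opt, z_opt])

# step() leaves .grad untouched, so it still holds the restored gradient.
for var, ref in zip((x, y, z), expected):
    print(torch.allclose(var.grad, ref))  # True for each variable
```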

If anyone could confirm that this is correct (and accounts for all cases), that would be great.

I haven’t verified your code, but would torch.autograd.grad work for your use case?
It computes the gradients of a loss with respect to only the inputs you pass to it.
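A minimal sketch of that suggestion, assuming each optimizer's parameters should receive only the gradient of its matching loss. `multi_step_autograd` is a hypothetical helper name, and the two-variable toy losses are made up for illustration. Because `torch.autograd.grad` returns the gradients instead of accumulating them into `.grad`, nothing has to be saved and restored. All gradients are still computed before any optimizer steps, because `.step()` updates the parameters in place.

```python
import torch

def multi_step_autograd(losses, optms):
    # Hypothetical helper: differentiate losses[i] only w.r.t. the parameters
    # owned by optms[i], then step every optimizer.
    params = [[p for group in o.param_groups for p in group["params"]]
              for o in optms]
    all_grads = []
    for i, (loss, ps) in enumerate(zip(losses, params)):
        # Keep the shared graph for every loss except the last one;
        # allow_unused=True yields None for parameters the loss doesn't use.
        all_grads.append(torch.autograd.grad(
            loss, ps, retain_graph=i != len(losses) - 1, allow_unused=True))
    # Assign and step only once every gradient has been computed.
    for optm, ps, gs in zip(optms, params, all_grads):
        for p, g in zip(ps, gs):
            p.grad = g
        optm.step()

# Toy example (made up): x_loss = 2*s and y_loss = s**2 share s = sum(x*y).
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
x_opt = torch.optim.SGD([x], lr=0.01)
y_opt = torch.optim.SGD([y], lr=0.01)

s = (x * y).sum()   # 1*3 + 2*3 = 9
x_loss = 2 * s      # d(x_loss)/dx = 2*y = [6, 6]
y_loss = s ** 2     # d(y_loss)/dy = 2*s*sum(x) = 54
multi_step_autograd([x_loss, y_loss], [x_opt, y_opt])

print(x.grad, y.grad)          # tensor([6., 6.]) tensor([54.])
print(x.detach(), y.detach())  # tensor([0.9400, 1.9400]) tensor([2.4600])
```

Here `y.grad` is 54, not the 60 it would be if the gradient of `x_loss` with respect to `y` (2 * sum(x) = 6) had been accumulated as well.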