I am writing a program with essentially the following update rules (with learning rate 1), which the code below makes concrete:

    x ← x − d/dx g(x, y)
    y ← y − d/dy d/dx f(x, y)

I have the following simplified code:
```python
import torch
from torch import optim

def f(x, y):
    return (x + y)**2

def g(x, y):
    return x * y

x = torch.tensor([3.], requires_grad=True)
y = torch.tensor([4.], requires_grad=True)
x_optim = optim.SGD([x], lr=1.)
y_optim = optim.SGD([y], lr=1.)

ddx, = torch.autograd.grad(f(x, y).mean(), x, create_graph=True)  # 2(x + y) = 14
ddx.mean().backward()
# x.grad = d^2/dx^2 f(x, y) = 2
# y.grad = d/dy d/dx f(x, y) = 2

ddx, = torch.autograd.grad(g(x, y).mean(), x)
x.grad = ddx  # x.grad = d/dx g(x, y) = y = 4

y_optim.step()  # y = 2, x = 3
x_optim.step()  # y = 2, x = -1
```
My question is: is this the best (and most performant) way to do this?
`x.grad = ddx` is not so bad here, but when the variables are all the parameters of several neural networks, this approach requires a lot of careful matching of gradients to variables, with a lot of room for error.
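To make the concern concrete, here is a minimal sketch (the two `nn.Linear` networks and the parameter list are hypothetical, not from my real program) of what that manual bookkeeping looks like with more than one network: `torch.autograd.grad` returns a tuple of gradients that must be matched back to the right parameters purely by position.

```python
import torch
from torch import nn

# Hypothetical stand-ins for several networks.
net_a = nn.Linear(2, 2)
net_b = nn.Linear(2, 2)
params = list(net_a.parameters()) + list(net_b.parameters())

x = torch.randn(4, 2)
loss = (net_b(net_a(x)) ** 2).mean()

# grads[i] corresponds to params[i] by position only.
grads = torch.autograd.grad(loss, params)
for p, g in zip(params, grads):
    p.grad = g  # any reordering of `params` silently mismatches gradients
```

The correctness of the `zip` relies entirely on the two lists staying in the same order, which is exactly the error-prone part.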