I’m trying to train a function `f` of weights `w` that takes some inputs `x` and outputs `y`. The loss is `MSELoss()(y, target)`. I want to optimize `f` such that, given a random initial `x`, gradient updates on `x` minimize the loss. In other words, `f` should be trained so that, starting from a random initial guess, gradient descent on `x` finds the solution in as few steps as possible.

```
# Random initial guess for x, optimized by the inner loop.
initial_x = torch.randn(dim, requires_grad=True)
opt = optim.Adam([initial_x], lr=0.001)   # inner optimizer over x
opt2 = optim.Adam([w], lr=0.001)          # outer optimizer over w

for i in range(10):
    # Inner step: gradient of the loss w.r.t. x, keeping the graph
    # so the outer objective can see through the inner update.
    y = f(w, initial_x)
    loss = MSELoss()(y, target)
    initial_x.grad = torch.autograd.grad(loss, initial_x, retain_graph=True, create_graph=True)[0]
    opt.step()

    # Outer step: update w on the post-update loss.
    loss = MSELoss()(f(w, initial_x), target)
    opt2.zero_grad()
    loss.backward()
    opt2.step()
```

However, this makes the occupied memory grow with every iteration. Any fix?
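My suspicion (just a guess) is that `create_graph=True` together with `retain_graph=True` keeps every iteration’s graph alive, since `initial_x.grad` holds a reference to it. Taking the inner gradient without building a graph, as in the sketch below, should let old graphs be freed each iteration, but I assume it also severs any higher-order gradient flow through the inner update:

```
    # My assumption: compute the inner gradient without create_graph /
    # retain_graph, so each iteration's graph can be garbage-collected.
    grad_x = torch.autograd.grad(loss, initial_x)[0]
    initial_x.grad = grad_x
    opt.step()
```

Is that the right trade-off here, or is there a way to keep memory bounded without losing those gradients?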