Parameter update to minimize a loss that doesn't directly depend on said parameters

This sounds counterintuitive but let me explain.

I have 2 nn.Modules, let’s call them model and input_generator.
input_generator accepts a data point x as input, and outputs a vector x_new. x_new is in turn passed to the model, and model is trained to minimize the cross entropy loss between model(x_new) and target. In code:

# assume: input variable 'x', class 'target', optimizer on model.parameters, F = torch.nn.functional
x_new = input_generator(x)
model_optimizer.zero_grad()
logits = model(x_new)
loss = F.cross_entropy(logits, target)
loss.backward()
model_optimizer.step()

So far so good. Now the question is: how is input_generator trained? Well, I want input_generator to generate inputs that minimize a different objective for model. Say we have a completely unrelated input x_other. I want input_generator to generate x_new inputs that minimize the cross-entropy loss between model(x_other) and other_target. Expanding on the previous code snippet:

x_new = input_generator(x)
model_optimizer.zero_grad()
logits = model(x_new)
loss = F.cross_entropy(logits, target)
loss.backward()
model_optimizer.step()
# assume: unrelated input 'x_other', class 'other_target', optimizer on input_generator.parameters
ig_optimizer.zero_grad()
other_logits = model(x_other)
other_loss = F.cross_entropy(other_logits, other_target)
other_loss.backward()
ig_optimizer.step()  # this won't work of course

Now, you see the problem. There is no way to optimize input_generator with my “off-policy” objective, because input_generator plays no role whatsoever in the creation of x_other.

So, my dear community, is there a way to say: “Please, input_generator, generate x_new in a way that minimizes the cross-entropy loss of model on a seemingly unrelated input, once model has taken a gradient step with your generated x_new”?

Thank you and all the best.

Hi @peustr
I don’t think there is a theoretically sound way to do that, at least.

other_loss is a function of other_logits (and of other_target, but that is a constant and hence not optimizable).
other_logits is in turn a function of the model parameters only. So, when you call other_loss.backward(), the gradients of other_loss with respect to the model parameters will be computed.

ig_optimizer.step() needs the gradients of other_loss with respect to the parameters of input_generator in order to update them.
The parameters of input_generator have nothing to do with the value that other_loss takes. That is to say, you cannot really “optimise” the parameters of input_generator so that other_loss takes an optimised value.
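
To see this concretely, here is a minimal sketch of your setup (the shapes are hypothetical):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(3, 5)
input_generator = torch.nn.Linear(2, 3)

x_other = torch.randn(1, 3)
other_target = torch.randint(5, (1,))

other_loss = F.cross_entropy(model(x_other), other_target)
other_loss.backward()

print(model.weight.grad is None)            # False: model received gradients
print(input_generator.weight.grad is None)  # True: input_generator did not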

Let’s also think of this with a simpler example:

import torch

y = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([1.0], requires_grad=True)
z = y * y

z is a function of tensor y. To minimize z, x is simply irrelevant.

Hi @srishti-git1110, thank you for your reply.

I understand completely what you are saying, but the example you provide is not a direct analogy.

If we think about it, there exists a path from x to other_loss that passes through x_new. The path is as follows:

  • x_new is generated by input_generator(x), i.e., a function with parameters φ
  • x_new is given to model, produces loss
  • loss is used to update model parameters θ_t to θ_{t+1}
  • given a constant input x_other, model with parameters θ_{t+1} produces other_loss
  • what is the x_new that would minimize other_loss, if model with parameters θ_t had used it to perform the gradient update instead?
  • update φ so that it produces a better x_new in the next iteration

I am aware that this may not be possible (although it seems like it could be, with second-order gradients or something); I am just demonstrating that the example you provided does not represent my problem fully.
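
For what it’s worth, here is how I imagine a one-step unrolled version could look, following the bullet points above (just a sketch, not tested; torch.func.functional_call needs PyTorch ≥ 2.0, older versions have torch.nn.utils.stateless.functional_call, and all shapes are hypothetical):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(3, 5)
input_generator = torch.nn.Linear(2, 3)
ig_optimizer = torch.optim.SGD(input_generator.parameters(), lr=1e-3)
inner_lr = 0.1  # learning rate for the model's (unrolled) gradient step

x = torch.randn(1, 2)
target = torch.randint(5, (1,))
x_other = torch.randn(1, 3)
other_target = torch.randint(5, (1,))

# inner step: theta_t -> theta_{t+1}, kept differentiable w.r.t. phi via x_new
x_new = input_generator(x)
loss = F.cross_entropy(model(x_new), target)
names = [n for n, _ in model.named_parameters()]
params = [p for _, p in model.named_parameters()]
grads = torch.autograd.grad(loss, params, create_graph=True)
theta_next = {n: p - inner_lr * g for n, p, g in zip(names, params, grads)}

# outer step: evaluate model at theta_{t+1} on x_other, backprop into phi
other_logits = torch.func.functional_call(model, theta_next, (x_other,))
other_loss = F.cross_entropy(other_logits, other_target)
ig_optimizer.zero_grad()
other_loss.backward()  # note: this also leaves second-order grads on model's .grad
ig_optimizer.step()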

Thanks again!

@peustr
Yes, I included that example just to convey that we cannot, traditionally, optimise a function with respect to some variable that it does not mathematically depend on (directly or indirectly).

But, what you detailed in the bullet points makes sense to me. Thanks for the clarification.

The task now will be to put these points in the form of mathematical equations, to get a sense of how autograd could be utilised to update input_generator’s parameters.
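
As a first pass, and using the notation from your bullet points (with η as model’s learning rate), I believe they translate to something like:

θ_{t+1} = θ_t − η ∇_{θ_t} CE(model_{θ_t}(x_new), target),   where x_new = input_generator_φ(x)

min_φ CE(model_{θ_{t+1}}(x_other), other_target)

Since θ_{t+1} depends on φ only through x_new, the gradient of the second expression with respect to φ has to flow through the first update, which is where the second-order derivatives would come in.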
I’ll get back to you on this if I’m able to figure something out.

Hi Panagiotis!

If I understand your use case correctly, you would like to train input_generator
so that it produces samples of x_new which, when used to train model with the
loss function cross_entropy (model (x_new), target), cause model to learn
(from being trained with x_new) to minimize
cross_entropy (model (x_other), target).

Based on this understanding, one approach would be to train input_generator
to produce x_new such that cross_entropy (model (x_new), target) will
have a gradient that trains model in the way you want. To do this, you could
use the gradient of cross_entropy (model (x_other), target) as the
“ground-truth” target gradient for the gradient produced by
cross_entropy (model (x_new), target) and then train input_generator
to produce samples of x_new that produce such gradients.

In the example below, target_grad_weight and target_grad_bias are the
“target” gradients for the weight and bias of a simple Linear model.
We then compute x_new = input_generator (x) (also for a simple Linear
input_generator) and compute gen_grads, the gradient with respect to
model of cross_entropy (model (x_new), target). We then use the
mean-squared-error between gen_grads and the target gradients as the
loss function with which to train input_generator.

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> model = torch.nn.Linear (3, 5)
>>> input_generator = torch.nn.Linear (2, 3)
>>> x_other = torch.randn (1, 3)
>>> target = torch.randint (5, (1,))
>>>
>>> loss_model_other = torch.nn.functional.cross_entropy (model (x_other), target)
>>> loss_model_other.backward()
>>> target_grad_weight = model.weight.grad.clone()
>>> target_grad_bias = model.bias.grad.clone()
>>> model.zero_grad()
>>>
>>> x = torch.randn (1, 2)
>>> x_new = input_generator (x)
>>> loss_model_new = torch.nn.functional.cross_entropy (model (x_new), target)
>>> gen_grads = torch.autograd.grad (loss_model_new, (model.weight, model.bias), create_graph = True)
>>>
>>> loss_input_generator  = torch.nn.functional.mse_loss (gen_grads[0], target_grad_weight)
>>> loss_input_generator += torch.nn.functional.mse_loss (gen_grads[1], target_grad_bias)
>>>
>>> loss_input_generator.backward()
>>> input_generator.bias.grad
tensor([ 0.3475, -0.2202, -0.0620])

We do, in fact, compute a second-order derivative of a sort, in that we
compute the gradient of a function of a gradient, and we use
autograd.grad (create_graph = True) as an intermediate step in
doing so.
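
If it’s helpful, the same steps can be packaged as a single training step for
input_generator, along the lines of the following sketch (ig_optimizer is a
hypothetical optimizer over input_generator.parameters(); the toy shapes match
the transcript above):

import torch

model = torch.nn.Linear (3, 5)
input_generator = torch.nn.Linear (2, 3)
ig_optimizer = torch.optim.SGD (input_generator.parameters(), lr = 1.e-3)

def input_generator_step (x, target, x_other):
    # "ground-truth" gradient from the unrelated input (no graph needed)
    other_loss = torch.nn.functional.cross_entropy (model (x_other), target)
    target_grads = torch.autograd.grad (other_loss, (model.weight, model.bias))
    # gradient produced by the generated input, kept differentiable wrt input_generator
    x_new = input_generator (x)
    new_loss = torch.nn.functional.cross_entropy (model (x_new), target)
    gen_grads = torch.autograd.grad (new_loss, (model.weight, model.bias), create_graph = True)
    # match the two gradients and take a step on input_generator
    ig_loss = sum (torch.nn.functional.mse_loss (g, t) for g, t in zip (gen_grads, target_grads))
    ig_optimizer.zero_grad()
    ig_loss.backward()
    ig_optimizer.step()
    model.zero_grad()   # drop the second-order grads that also landed on model
    return ig_loss

x = torch.randn (1, 2)
target = torch.randint (5, (1,))
x_other = torch.randn (1, 3)
_ = input_generator_step (x, target, x_other)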

Best.

K. Frank

Hi @KFrank, thank you so much for your reply; this is a very fresh perspective for me, and I can’t wait to go to the office on Monday and try it out!
I will mark your response as the solution for now, because it makes total sense to me that it should work, but I will make sure to leave an update here soon. Cheers!