Using a model to filter gradients from a classifier

System explanation

What I’m trying to accomplish is a system with two models: a classifier, and what I call a “gradient_model”.

The goal of the classifier is to learn features from a training dataset as normal.

The goal of the gradient_model is to learn to filter/aggregate/transform the gradients from the classifier so that the classifier performs well on the validation dataset, explicitly without having any feature information from the validation set.


Algorithm explanation

  1. Get train_batch
  2. Compute the classifier's loss on train_batch
  3. Compute the gradients of the loss w.r.t. the classifier weights, separately for every environment
    This should result in a tensor of size [num_envs, …]
  4. Update the classifier’s weights by applying the gradient_model
    new_weight = old_weight - gradient_model(gradients).squeeze(0)
    The gradient_model converts the [num_envs, …] tensor to [1, …], which we then squeeze along the first dimension (a plain squeeze() would also drop any size-1 dimension of the weight itself).
  5. Get val_batch
  6. Compute val_loss for the updated classifier on val_batch
  7. In order to update the gradient_model, I think we need to:
    a) compute the gradient of val_loss w.r.t. the new classifier weights
    b) since the new weights depend on the gradient_model, backpropagate a) through step 4 to get the gradient of val_loss w.r.t. the gradient_model's parameters
  8. Update the gradient_model
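The steps above can be sketched end to end in PyTorch. This is a minimal sketch under several assumptions: PyTorch 2.0+ for torch.func.functional_call, a toy nn.Linear classifier, a hypothetical fake_batch helper producing random tensors in place of real per-environment and validation batches, and a hypothetical GradientModel (a softmax-weighted average over environments times a learnable step size) standing in for whatever gradient_model is actually used. The key design choice is that the new weights are never written into the module before the validation forward pass; they are passed to functional_call instead, so val_loss stays connected to the gradient_model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch >= 2.0

torch.manual_seed(0)
num_envs, in_dim, n_classes = 3, 4, 2

# Toy classifier (assumption: the real one is any nn.Module).
classifier = nn.Linear(in_dim, n_classes)


class GradientModel(nn.Module):
    # Hypothetical gradient_model: softmax-weighted average over environments,
    # scaled by a learnable positive step size. Maps [num_envs, ...] -> [1, ...].
    def __init__(self, num_envs):
        super().__init__()
        self.env_logits = nn.Parameter(torch.zeros(num_envs))
        self.log_lr = nn.Parameter(torch.tensor(0.1).log())

    def forward(self, grads):
        w = torch.softmax(self.env_logits, 0)
        w = w.view(-1, *([1] * (grads.dim() - 1)))  # broadcast over param dims
        return self.log_lr.exp() * (w * grads).sum(0, keepdim=True)


gradient_model = GradientModel(num_envs)
gm_opt = torch.optim.Adam(gradient_model.parameters(), lr=1e-2)


def fake_batch(n=8):
    # Hypothetical stand-in for a real data loader.
    return torch.randn(n, in_dim), torch.randint(0, n_classes, (n,))


def meta_step(env_batches, val_batch):
    names, params = zip(*classifier.named_parameters())

    # Steps 1-3: per-environment loss and gradients w.r.t. classifier weights.
    per_env = [torch.autograd.grad(F.cross_entropy(classifier(x), y), params)
               for x, y in env_batches]
    # One [num_envs, *param.shape] tensor per parameter.
    stacked = [torch.stack(gs) for gs in zip(*per_env)]

    # Step 4: new weights as a differentiable function of the gradient_model.
    # squeeze(0) drops only the environment dimension.
    new_params = {n: p - gradient_model(g).squeeze(0)
                  for n, p, g in zip(names, params, stacked)}

    # Steps 5-6: forward pass on val_batch with the new weights, without
    # writing them into the module, so the autograd graph stays intact.
    x_val, y_val = val_batch
    logits = functional_call(classifier, new_params, (x_val,))
    val_loss = F.cross_entropy(logits, y_val)

    # Steps 7-8: backpropagate val_loss into the gradient_model only.
    gm_opt.zero_grad()
    val_loss.backward(inputs=list(gradient_model.parameters()))
    gm_opt.step()

    # Commit the step-4 update to the classifier, outside of autograd.
    with torch.no_grad():
        for n, p in zip(names, params):
            p.copy_(new_params[n])
    return val_loss.item()


for step in range(5):
    env_batches = [fake_batch() for _ in range(num_envs)]
    print(step, meta_step(env_batches, fake_batch()))
```

Passing inputs= to backward limits gradient accumulation to the gradient_model's parameters; the classifier is changed only by the explicit copy at the end of meta_step, after the gradient_model has been updated.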

The problem

To train the gradient_model, I think I need to backpropagate the validation loss through the updated weights of the classifier.
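Steps 7a and 7b are the chain rule through step 4. In notation introduced here only for illustration (θ for the classifier weights, φ for the gradient_model's parameters, E = num_envs, G for the stacked per-environment training gradients), the update and the gradient needed for step 8 are:

```latex
G = \big[\nabla_\theta \mathcal{L}^{\text{train}}_1(\theta), \dots, \nabla_\theta \mathcal{L}^{\text{train}}_E(\theta)\big],
\qquad
\theta' = \theta - g_\phi(G)

\nabla_\phi \mathcal{L}^{\text{val}}(\theta')
  = -\left(\frac{\partial g_\phi(G)}{\partial \phi}\right)^{\!\top}
    \nabla_{\theta'} \mathcal{L}^{\text{val}}(\theta')
```

For a single update step, G is computed from θ alone and does not depend on φ, so it acts as a constant input to the gradient_model and this gradient is exact without differentiating through step 3 (no second derivatives of the training loss are needed). What it does require is that θ' remain a differentiable function of φ when it is used in the forward pass of step 6.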

Using torchviz, I can see that the graph breaks between step 4, where we create the new classifier weights, and steps 5 and 6, where we use the new classifier weights in a new forward pass on val_batch.

If possible, I would like the autograd graph (the gradient tape) to continue through this new forward pass of the model.
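One plausible cause of the break, assuming step 4 writes the new weights back into the classifier (for example with param.data = ... or an in-place update under torch.no_grad()), is that the updated parameters become leaf tensors with no history, so nothing computed after step 4 can reach the gradient_model. The sketch below contrasts that with a functional forward pass using torch.func.functional_call (PyTorch 2.0+). A toy nn.Linear, an MSE loss, and a single learnable step_size are hypothetical stand-ins for the real classifier, loss, and gradient_model.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

torch.manual_seed(0)
model = nn.Linear(3, 1)
# Hypothetical stand-in for the gradient_model: one learnable step size.
step_size = nn.Parameter(torch.tensor(0.1))
x_tr, y_tr = torch.randn(8, 3), torch.randn(8, 1)
x_val, y_val = torch.randn(8, 3), torch.randn(8, 1)


def mse(pred, target):
    return ((pred - target) ** 2).mean()


names, params = zip(*model.named_parameters())
grads = torch.autograd.grad(mse(model(x_tr), y_tr), params)

# (a) Writing the update into (a copy of) the module under no_grad: the
# parameters are plain leaf tensors afterwards, so val_loss_a has no path
# back to step_size.
broken = copy.deepcopy(model)
with torch.no_grad():
    for p, g in zip(broken.parameters(), grads):
        p -= step_size * g
val_loss_a = mse(broken(x_val), y_val)
(grad_a,) = torch.autograd.grad(val_loss_a, step_size, allow_unused=True)
print("in-place update, grad w.r.t. step_size:", grad_a)  # grad_a is None

# (b) Keeping the new weights as graph tensors and running the forward pass
# functionally: val_loss_b stays connected to step_size.
new_params = {n: p - step_size * g for n, p, g in zip(names, params, grads)}
val_loss_b = mse(functional_call(model, new_params, (x_val,)), y_val)
(grad_b,) = torch.autograd.grad(val_loss_b, step_size)
print("functional update, grad w.r.t. step_size:", grad_b)
```

Rendering val_loss_a with torchviz would show the graph stopping at the copied parameters, while val_loss_b's graph should reach step_size through the subtraction in new_params.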