What I’m trying to build is a system with two models: a classifier, and what I’ll call a “gradient_model”.
The goal of the classifier is to learn features from a training dataset, as usual.
The goal of the gradient_model is to learn to filter/aggregate/transform the classifier’s gradients so that the classifier performs well on the validation dataset, explicitly without having any feature information from the validation set.
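As a concrete, hypothetical example of the [num_envs, …] → [1, …] contract, a gradient_model could learn one mixing weight per environment plus a positive step size. The class name `EnvGradientAggregator`, the softmax weighting and the `log_lr` parameter are my own assumptions, not part of the setup above; any module with the same input and output shapes would fit.

```python
import torch
import torch.nn as nn


class EnvGradientAggregator(nn.Module):
    """Hypothetical gradient_model: a softmax-weighted sum of per-environment
    gradients, scaled by a learnable positive step size."""

    def __init__(self, num_envs: int):
        super().__init__()
        # One logit per environment; softmax turns them into mixing weights.
        self.env_logits = nn.Parameter(torch.zeros(num_envs))
        # Log of the step size, so the effective learning rate stays positive.
        self.log_lr = nn.Parameter(torch.tensor(-2.0))

    def forward(self, grads: torch.Tensor) -> torch.Tensor:
        # grads: [num_envs, *param_shape]
        weights = torch.softmax(self.env_logits, dim=0)
        # Reshape to [num_envs, 1, 1, ...] so the weights broadcast over parameter dims.
        weights = weights.view(-1, *([1] * (grads.dim() - 1)))
        combined = (weights * grads).sum(dim=0, keepdim=True)  # [1, *param_shape]
        return self.log_lr.exp() * combined
```

With zero-initialized logits this starts out as plain SGD on the environment-averaged gradient, which is a reasonable starting point for meta-learning the mix.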
1. Get train_batch.
2. Compute the classifier’s loss on train_batch.
3. Compute the gradient of the loss w.r.t. the classifier weights for every environment.
   This should result in a tensor of size [num_envs, …] for each weight tensor.
4. Update the classifier’s weights by applying the gradient_model:
   `new_weight = old_weight - gradient_model(gradients).squeeze()`
   The gradient_model converts the [num_envs, …] tensor to a [1, …] tensor, which we then squeeze.
5. Get val_batch.
6. Compute val_loss for the updated classifier on val_batch.
7. In order to update the gradient_model, I think we need to:
   a) compute the gradient of val_loss w.r.t. the new classifier weights;
   b) since the new weights depend on the gradient_model, backpropagate a) through the update in step 4 to get the gradient of val_loss w.r.t. the gradient_model’s parameters.
8. Update the gradient_model.
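The steps above can be sketched as a single PyTorch training iteration, assuming PyTorch 2.0 or later for `torch.func.functional_call`. The key idea is to never write the updated weights back into the `nn.Module` before the validation forward pass: the new weights are computed out of place as ordinary tensors and passed to `functional_call`, so the autograd graph runs from `val_loss` through the update and into the gradient_model. Everything concrete here is hypothetical: a linear classifier, random per-environment data, and a `GradientModel` that is a learned linear mix over the environment axis, initialized to averaged SGD. I use `.squeeze(0)` rather than `.squeeze()` so that parameter dimensions of size 1 are not removed by accident.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
NUM_ENVS, IN_DIM, NUM_CLASSES = 3, 10, 2


class GradientModel(nn.Module):
    """Hypothetical gradient_model: a learned linear mix over the env axis."""

    def __init__(self, num_envs, init_lr=0.1):
        super().__init__()
        self.mix = nn.Linear(num_envs, 1, bias=False)
        # Start out as plain SGD on the environment-averaged gradient.
        nn.init.constant_(self.mix.weight, init_lr / num_envs)

    def forward(self, grads):  # [num_envs, *param_shape] -> [1, *param_shape]
        return self.mix(grads.movedim(0, -1)).movedim(-1, 0)


classifier = nn.Linear(IN_DIM, NUM_CLASSES)
gradient_model = GradientModel(NUM_ENVS)
meta_opt = torch.optim.Adam(gradient_model.parameters(), lr=1e-2)

# Step 1 (hypothetical data): one random train batch per environment, plus a val batch.
train_batches = [(torch.randn(16, IN_DIM), torch.randint(0, NUM_CLASSES, (16,)))
                 for _ in range(NUM_ENVS)]
val_x, val_y = torch.randn(32, IN_DIM), torch.randint(0, NUM_CLASSES, (32,))

params = dict(classifier.named_parameters())

# Steps 2-3: per-environment gradients, stacked to [num_envs, *param_shape].
per_env = []
for x, y in train_batches:
    loss = F.cross_entropy(functional_call(classifier, params, (x,)), y)
    per_env.append(torch.autograd.grad(loss, list(params.values())))
stacked = {name: torch.stack([g[i] for g in per_env])
           for i, name in enumerate(params)}

# Step 4: out-of-place update, so new_params stay connected to gradient_model.
new_params = {name: p - gradient_model(stacked[name]).squeeze(0)
              for name, p in params.items()}

# Steps 5-6: forward pass on the val batch with the *new* weights.
val_loss = F.cross_entropy(functional_call(classifier, new_params, (val_x,)), val_y)

# Steps 7-8: backprop val_loss through the update into gradient_model.
# (This also fills .grad on the classifier's own parameters; it is unused here.)
meta_opt.zero_grad()
val_loss.backward()
meta_opt.step()

# Commit the classifier update now that the graph for this step is no longer needed.
with torch.no_grad():
    for name, p in classifier.named_parameters():
        p.copy_(new_params[name])
```

For a single step, `create_graph=True` is not needed in `torch.autograd.grad`, because the train gradients depend only on the old classifier weights and not on the gradient_model. If several inner updates are unrolled before computing `val_loss`, the later train gradients do depend on the gradient_model, and passing `create_graph=True` becomes necessary.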
To train the gradient_model, I think I need to backpropagate the validation loss through the classifier’s updated weights.
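Writing θ for the classifier weights, φ for the gradient_model parameters, g_φ for the gradient_model and E for the number of environments (symbols introduced here for illustration), the chain rule makes steps 7a and 7b explicit:

```latex
\theta' = \theta - g_\phi(G), \qquad
G = \big[\nabla_\theta \mathcal{L}^{(1)}_{\mathrm{train}}(\theta), \dots, \nabla_\theta \mathcal{L}^{(E)}_{\mathrm{train}}(\theta)\big]

\nabla_\phi \mathcal{L}_{\mathrm{val}}(\theta')
  = \left(\frac{\partial \theta'}{\partial \phi}\right)^{\!\top} \nabla_{\theta'} \mathcal{L}_{\mathrm{val}}(\theta')
  = -\left(\frac{\partial g_\phi(G)}{\partial \phi}\right)^{\!\top} \nabla_{\theta'} \mathcal{L}_{\mathrm{val}}(\theta')
```

The factor ∇_{θ'} L_val is step 7a and ∂θ'/∂φ is step 7b. Because G depends only on θ, a single step needs no second derivatives of the training loss; it only needs θ' to be computed from tensors that autograd is still tracking.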
Using torchviz, I can see that the graph breaks between step 4, where the new classifier weights are created, and the new forward pass on val_batch that uses them.
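Without the actual code I can only guess at the cause, but a common one is writing the update back into the existing `nn.Parameter` (under `torch.no_grad()`, via `.data`, or by wrapping the result in a new `nn.Parameter`, which creates a fresh leaf). The minimal scalar sketch below contrasts that pattern with an out-of-place update; `phi`, `w` and `g` are hypothetical stand-ins for the gradient_model parameters, a classifier weight and the train gradient.

```python
import torch

g = torch.tensor(1.0)  # stands in for the (already computed) train gradient

# Pattern that breaks the graph: write the update into the existing Parameter.
phi_a = torch.tensor(0.5, requires_grad=True)  # stands in for gradient_model params
w_a = torch.nn.Parameter(torch.tensor(2.0))    # stands in for a classifier weight
with torch.no_grad():
    w_a.copy_(w_a - phi_a * g)                 # w_a is still a leaf with no link to phi_a
((w_a * 3.0) ** 2).backward()
broken_grad = phi_a.grad                       # None: the update was never recorded

# Pattern that keeps the graph: compute the new weight out of place.
phi_b = torch.tensor(0.5, requires_grad=True)
w_b = torch.nn.Parameter(torch.tensor(2.0))
new_w = w_b - phi_b * g                        # new_w.grad_fn is not None
((new_w * 3.0) ** 2).backward()
kept_grad = phi_b.grad                         # d/dphi of 9 * (2 - phi)^2 at phi = 0.5, i.e. -27

print(broken_grad, kept_grad)
```

In torchviz terms, the first version shows `w_a` as a disconnected leaf in the validation graph, while the second shows a path from the loss back to `phi_b`.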
If possible, I would like the gradient tape (the autograd graph) to continue from the weight update into the new forward pass of the model.