Best practice for prediction using new_weights


First, I initialize the model parameters and run a prediction on some input data:

import torch

pred = learner(adapt_data)
loss = loss_fn(pred, adapt_labels)
grad = torch.autograd.grad(loss, learner.parameters())
# theta' = theta - alpha*grads, with alpha = 0.4
new_weight = [p - 0.4 * g for p, g in zip(learner.parameters(), grad)]
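To check that the adapted weights are still connected to the original parameters, here is a minimal, self-contained sketch. It assumes a toy nn.Linear(3, 1) as learner, nn.MSELoss() as loss_fn, and random tensors for adapt_data and adapt_labels; none of these come from the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post
learner = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
adapt_data = torch.randn(8, 3)
adapt_labels = torch.randn(8, 1)
alpha = 0.4  # inner-loop step size used in the post

pred = learner(adapt_data)
loss = loss_fn(pred, adapt_labels)
# autograd.grad returns the gradients without writing them into .grad
grad = torch.autograd.grad(loss, learner.parameters())

# theta' = theta - alpha * grad; ordinary tensor ops, so autograd records them
new_weight = [p - alpha * g for p, g in zip(learner.parameters(), grad)]

# Each adapted tensor has a grad_fn, i.e. it is not a detached copy
print([w.grad_fn is not None for w in new_weight])  # prints [True, True]
```

Because torch.autograd.grad is called here without create_graph=True, grad itself is treated as a constant, so backpropagating through new_weight later gives the first-order approximation; the path back to learner's parameters through p is still there.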

Now I want to use these new weights to predict on evaluation data without writing them into the model learner() itself:

pred_val = learner(eval_data, new_weight)

But this won't work, because learner() takes only one argument (two, counting self). How should I work around this? One option I thought of was to use deepcopy(learner) to make a replica of the model, say new_learner(), and update the parameters of that copy. The problem is that I can't use loss.backward() or opt.step() for this, as that would update the original model.
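One way to run the model with new_weight without writing it into learner is torch.func.functional_call (available in PyTorch 2.0 and later), which substitutes a dict of tensors for the module's parameters for a single call. A minimal sketch, assuming a hypothetical toy nn.Linear model and random data in place of the post's learner, loss_fn, and datasets:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post
learner = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
adapt_data, adapt_labels = torch.randn(8, 3), torch.randn(8, 1)
eval_data = torch.randn(4, 3)

loss = loss_fn(learner(adapt_data), adapt_labels)
names = [name for name, _ in learner.named_parameters()]
params = [p for _, p in learner.named_parameters()]
grad = torch.autograd.grad(loss, params)

# Map parameter names to the adapted tensors theta' = theta - 0.4 * grad
new_weight = {n: p - 0.4 * g for n, p, g in zip(names, params, grad)}

weight_before = learner.weight.detach().clone()
# Forward pass with theta'; learner's own parameters are not modified
pred_val = functional_call(learner, new_weight, (eval_data,))
pred_orig = learner(eval_data)
```

Because new_weight is built from learner's parameters with ordinary tensor ops, pred_val stays in the same autograd graph as learner, which is exactly the link a deepcopy-based approach loses.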

I would really appreciate any help.

So I tried something like the following. I am not sure it is correct, so feel free to correct me if it is bad practice.

I made a deepcopy of the original model; let's call it new_learner(). As above, I computed the new weights into the variable new_weight. I then used them to update new_learner without calling opt.step():

with torch.no_grad():
    for i, (name, params) in enumerate(new_learner.named_parameters()):
        params.copy_(new_weight[i])  # write theta' into the copy, in place
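The torch.no_grad() block is the likely place where the graph breaks: the in-place copy is not recorded by autograd, so new_learner's parameters end up as independent leaf tensors with no link back to learner. A minimal sketch reproducing this, assuming the loop body is params.copy_(new_weight[i]) and using a hypothetical toy model and random data:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post
learner = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
adapt_data, adapt_labels = torch.randn(8, 3), torch.randn(8, 1)
eval_data, eval_labels = torch.randn(4, 3), torch.randn(4, 1)

loss = loss_fn(learner(adapt_data), adapt_labels)
grad = torch.autograd.grad(loss, learner.parameters())
new_weight = [p - 0.4 * g for p, g in zip(learner.parameters(), grad)]

new_learner = copy.deepcopy(learner)
with torch.no_grad():
    for i, (name, params) in enumerate(new_learner.named_parameters()):
        params.copy_(new_weight[i])  # not recorded by autograd

loss_val = loss_fn(new_learner(eval_data), eval_labels)
loss_val.backward()

# Gradients land on the copy's leaf parameters, never on the original model
print([p.grad is None for p in learner.parameters()])      # prints [True, True]
print([p.grad is None for p in new_learner.parameters()])  # prints [False, False]
```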

Once I have evaluated new_learner on eval_data, I delete it with del new_learner. I then use the loss from that prediction to update the original model's weights:

pred_val = new_learner(eval_data)
loss_val = loss_fn(pred_val, eval_labels)
del new_learner
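A sketch of the whole step without deepcopy, using torch.func.functional_call (PyTorch 2.0+) for the evaluation pass so that backward() reaches learner's own parameters and opt.step() can update them. The toy model, the random data, and the SGD optimizer with lr=0.01 are hypothetical choices for illustration:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post
learner = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
opt = torch.optim.SGD(learner.parameters(), lr=0.01)
adapt_data, adapt_labels = torch.randn(8, 3), torch.randn(8, 1)
eval_data, eval_labels = torch.randn(4, 3), torch.randn(4, 1)

weight_before = learner.weight.detach().clone()

opt.zero_grad()
# Inner step: theta' = theta - 0.4 * grad, kept as a name -> tensor dict
named = dict(learner.named_parameters())
loss = loss_fn(learner(adapt_data), adapt_labels)
grad = torch.autograd.grad(loss, list(named.values()))
new_weight = {n: p - 0.4 * g for (n, p), g in zip(named.items(), grad)}

# Outer step: evaluate with theta', backprop into the original parameters
pred_val = functional_call(learner, new_weight, (eval_data,))
loss_val = loss_fn(pred_val, eval_labels)
loss_val.backward()
opt.step()
```

Nothing needs to be deleted afterwards: new_weight is just a dict of intermediate tensors, and learner is the only module involved.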

And it's done! Please let me know if this is OK, or if there is a better practice I should be aware of. Thank you!

OK, so the problem with the above solution is that train_loss.backward() does not update the parameters of the learner() model at all.

So basically the graph is breaking somewhere: .grad comes out as None for every parameter of learner (e.g. list(learner.parameters())[0].grad is None).
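Separately from the broken graph: in the first snippet torch.autograd.grad is called without create_graph=True, so even with an intact graph the outer gradient treats the inner gradients as constants (the first-order approximation). If the full second-order gradient through theta' = theta - alpha*grads is wanted, create_graph=True is needed. A minimal sketch comparing the two, with a hypothetical toy model and a hypothetical helper outer_grad, using functional_call for the evaluation pass:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the objects in the post
learner = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
adapt_data, adapt_labels = torch.randn(8, 3), torch.randn(8, 1)
eval_data, eval_labels = torch.randn(4, 3), torch.randn(4, 1)

def outer_grad(second_order):
    """Gradient of the eval loss at theta' w.r.t. learner's original weights."""
    named = dict(learner.named_parameters())
    loss = loss_fn(learner(adapt_data), adapt_labels)
    # create_graph=True keeps the inner gradient differentiable
    grad = torch.autograd.grad(loss, list(named.values()), create_graph=second_order)
    new_weight = {n: p - 0.4 * g for (n, p), g in zip(named.items(), grad)}
    pred_val = functional_call(learner, new_weight, (eval_data,))
    loss_val = loss_fn(pred_val, eval_labels)
    return torch.autograd.grad(loss_val, list(named.values()))

first = outer_grad(second_order=False)
second = outer_grad(second_order=True)
print([torch.allclose(a, b) for a, b in zip(first, second)])
```

Both variants do reach learner's parameters; they differ only in whether the Hessian term of the inner update contributes to the outer gradient.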