How can I selectively disable the effects of forward pre-hooks during optimization?

Hello, I have a model whose weights have a forward pre-hook attached that is very similar to weight normalization (pytorch/weight_norm.py in the pytorch/pytorch GitHub repository). In weight normalization, the weights are reparameterized as w = g * v / ||v||, and g and v are the parameters updated during training.
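For reference, a forward pre-hook of this kind can be sketched as follows. This is a simplified, hypothetical stand-in, not the code of `torch.nn.utils.weight_norm`: the helper name `add_weight_norm_hook` and the attribute names `weight_g` / `weight_v` are assumptions. It removes `weight` from the module's parameters, stores g (one magnitude per output row) and v as new parameters, and recomputes w = g * v / ||v|| before every forward call.

```python
import torch
import torch.nn as nn

def add_weight_norm_hook(module: nn.Linear) -> None:
    """Hypothetical helper: reparameterize module.weight as w = g * v / ||v||."""
    weight = module.weight.detach().clone()
    del module._parameters["weight"]                        # w is no longer a trained parameter
    module.weight_v = nn.Parameter(weight)                  # direction v
    module.weight_g = nn.Parameter(weight.norm(dim=1, keepdim=True))  # one magnitude per output row

    def hook(mod, inputs):
        # Recompute w from g and v before every forward call.
        v = mod.weight_v
        mod.weight = mod.weight_g * v / v.norm(dim=1, keepdim=True)

    hook(module, ())                                        # give the module an initial weight
    module.register_forward_pre_hook(hook)

torch.manual_seed(0)
layer = nn.Linear(4, 3)
x = torch.randn(2, 4)
before = layer(x).detach()
add_weight_norm_hook(layer)
after = layer(x)                                            # same output, now computed from g and v
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_g', 'weight_v']
```

Because g starts at ||v||, the layer's output is unchanged right after the hook is added; from then on, gradients flow to g and v rather than to w.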

Suppose my parameters are a linear combination w* = aw + b, where the pre-trained weights w are frozen, a and b are initialized to ones and zeros respectively, and a and b require gradients. When I do a forward/backward pass through the model with the hooks, I get grad(loss(aw + b)). I would like to get grad(loss(b)) instead, i.e. the gradients of the model containing only the parameters b.
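Written out with the chain rule (assuming a and b have the same shape as w and act elementwise, with ⊙ denoting the elementwise product), both quantities are the same function ∂L/∂w*, evaluated at different weights:

$$
w^* = a \odot w + b, \qquad
\frac{\partial \mathcal{L}}{\partial b} = \left.\frac{\partial \mathcal{L}}{\partial w^*}\right|_{w^* = a \odot w + b}, \qquad
\frac{\partial \mathcal{L}}{\partial a} = w \odot \left.\frac{\partial \mathcal{L}}{\partial w^*}\right|_{w^* = a \odot w + b}
$$

$$
\nabla_b \, \mathcal{L}(b) = \left.\frac{\partial \mathcal{L}}{\partial w^*}\right|_{w^* = b}
$$

At initialization (a = 1, b = 0) the first is evaluated at the pre-trained w and the second at all-zero weights, so in general grad(loss(b)) cannot be read off the gradients of the hooked model and needs its own forward/backward pass at w* = b.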

Currently, to do that, I make a copy of the network before adding the hooks with register_forward_pre_hook and initialize its weights to 0. Then, at every update, I copy the parameters b from the big model (w* = aw + b) into that hook-free copy and do a forward/backward pass on it to get grad(loss(b)). My problem is efficiency: is there a way to suppress the effect of a and w during an extra forward/backward pass, so that I can get the gradients of the network using only b during the optimization?
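One possible way to avoid the second model, sketched under assumptions: make the hook itself switchable. The helper `add_affine_hook`, the attribute names `weight_w` / `weight_a` / `weight_b` and the flag `only_b` are all hypothetical; a and b are assumed elementwise (same shape as w), and only the weight of a bias-free `nn.Linear` is reparameterized. With `only_b` set, the hook assigns weight = b, so a and w drop out of the graph, and `torch.autograd.grad` returns grad(loss(b)) without writing into `weight_b.grad`, leaving the gradients from the regular pass untouched.

```python
import torch
import torch.nn as nn

def add_affine_hook(module: nn.Linear) -> None:
    """Hypothetical helper: weight = a * w + b, or weight = b when module.only_b is True."""
    w = module.weight.detach().clone()
    del module._parameters["weight"]
    module.register_buffer("weight_w", w)                  # frozen pre-trained w
    module.weight_a = nn.Parameter(torch.ones_like(w))     # a initialized to ones
    module.weight_b = nn.Parameter(torch.zeros_like(w))    # b initialized to zeros
    module.only_b = False                                  # switch read by the hook

    def hook(mod, inputs):
        if mod.only_b:
            # `* 1.0` turns the Parameter into a plain tensor, so nn.Module
            # stores it as an attribute instead of registering it again.
            mod.weight = mod.weight_b * 1.0
        else:
            mod.weight = mod.weight_a * mod.weight_w + mod.weight_b

    module.register_forward_pre_hook(hook)

torch.manual_seed(0)
layer = nn.Linear(4, 3, bias=False)
add_affine_hook(layer)
x, target = torch.randn(8, 4), torch.randn(8, 3)

def loss_fn():
    return nn.functional.mse_loss(layer(x), target)

# Regular pass: grad(loss(a*w + b)) accumulates into weight_a.grad / weight_b.grad.
loss_fn().backward()

# Extra pass on the same module with a and w switched off: grad(loss(b)).
# torch.autograd.grad returns the gradient without touching weight_b.grad.
layer.only_b = True
grad_b_only, = torch.autograd.grad(loss_fn(), layer.weight_b)
layer.only_b = False
```

This still costs one extra forward/backward pass (which the math requires), but it removes the second network and the per-step copy of b. With several hooked layers, the flag would have to be set on each of them, for example by looping over `model.modules()`.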

In summary, this is what I do:

Model 1: w* = aw + b (w pre-trained, a initialized to ones, b to zeros)
Model 2: b only (b initialized to zeros)

  1. Forward pass on model 1.
  2. Take the state dict of model 1 and copy only the parameters b into model 2.
  3. Forward/backward pass on model 2 to get the gradients grad(loss(b)).
  4. Update the parameters of model 1.
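The four steps above can be sketched roughly as follows, to make the per-step cost of the copy concrete. This is a minimal reconstruction under assumptions, not the actual code: a single bias-free `nn.Linear` stands in for the network, a and b are elementwise and stored as hypothetical attributes `weight_a` / `weight_b` with the frozen w kept as a buffer `weight_w`, the hook is a simplified stand-in, and step 1 is assumed to include the backward pass that step 4 needs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Model 1: w* = a*w + b, applied by a forward pre-hook (weights only, no bias).
model1 = nn.Linear(4, 3, bias=False)
w = model1.weight.detach().clone()
del model1._parameters["weight"]
model1.register_buffer("weight_w", w)                   # frozen pre-trained w
model1.weight_a = nn.Parameter(torch.ones_like(w))      # a = ones
model1.weight_b = nn.Parameter(torch.zeros_like(w))     # b = zeros

def affine_hook(mod, inputs):
    mod.weight = mod.weight_a * mod.weight_w + mod.weight_b

model1.register_forward_pre_hook(affine_hook)

# Model 2: plain layer without hooks, holding only b (initialized to zeros).
model2 = nn.Linear(4, 3, bias=False)
nn.init.zeros_(model2.weight)

opt = torch.optim.SGD([model1.weight_a, model1.weight_b], lr=0.1)
x, target = torch.randn(8, 4), torch.randn(8, 3)

for _ in range(3):
    opt.zero_grad()
    # 1. forward (and backward) on model 1
    nn.functional.mse_loss(model1(x), target).backward()
    # 2. copy only b from model 1's state dict into model 2
    model2.load_state_dict({"weight": model1.state_dict()["weight_b"]})
    # 3. forward/backward on model 2 gives grad(loss(b))
    model2.zero_grad()
    nn.functional.mse_loss(model2(x), target).backward()
    grad_b_only = model2.weight.grad.clone()
    # 4. update model 1
    opt.step()
```

Every iteration copies b into model 2 and keeps a second network in memory; only the extra forward/backward pass at w* = b is inherent to computing grad(loss(b)).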