Differentiable weight quantization

Hello all,

I have a question about the following code:
//
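# state_dict() returns a new dict, so this update() alone does not change the model
vgg16_qu.state_dict().update(quant_dict)
# load_state_dict copies the values of W into vgg16_qu's own parameters
vgg16_qu.load_state_dict(quant_dict)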
//
quant_dict is a dictionary containing W = f(w), where w are the weights of my original vgg16 network and W are the weights I use in their place in my network.
I want gradients with respect to my original parameters w, even though the weights used in the forward pass are a differentiable function of them, W = f(w).
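A minimal sketch of that setup, assuming PyTorch 2.x (for `torch.func.functional_call`) and using a small `nn.Sequential` as a stand-in for vgg16. The quantizer `f` is hypothetical: rounding to a fixed grid with a straight-through estimator so that W = f(w) stays differentiable. Because W is computed inside the autograd graph and passed to the forward pass, rather than copied into the model, `backward()` fills `.grad` on the original parameters w.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def f(w, scale=0.1):
    # Hypothetical quantizer: the forward value is w rounded to a grid of step `scale`;
    # the straight-through trick makes backward treat f as the identity, dW/dw = 1.
    w_q = torch.round(w / scale) * scale
    return w + (w_q - w).detach()


torch.manual_seed(0)
# Small stand-in for vgg16; its parameters play the role of w.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

x = torch.randn(3, 4)
target = torch.randn(3, 2)

# Build W = f(w) as part of the graph instead of copying it in with load_state_dict.
quant_params = {name: f(p) for name, p in model.named_parameters()}

# Forward pass with W substituted for w; the module's own parameters are not modified.
out = functional_call(model, quant_params, (x,))
loss = nn.functional.mse_loss(out, target)
loss.backward()

# The gradients are stored on the original parameters w.
for name, p in model.named_parameters():
    print(name, p.grad.shape)
```

Since the module keeps holding w, the usual optimizer (e.g. `torch.optim.SGD(model.parameters(), ...)`) would update w, and W is recomputed from it on every forward pass.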

The load_state_dict code above replaces the weights with W, but it computes the gradients with respect to W instead of w, which is not what I want.
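A tiny reproduction of that behaviour, using two hypothetical `nn.Linear` layers named `original` and `quantized` in place of the original vgg16 and `vgg16_qu`, with plain rounding as an illustrative f. `load_state_dict` copies the tensor values into the existing parameters without recording any link back to w, so backward only reaches the copied parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
original = nn.Linear(4, 2)   # hypothetical: holds w
quantized = nn.Linear(4, 2)  # hypothetical stand-in for vgg16_qu

# W = f(w); simple rounding to one decimal is enough to show the effect.
quant_dict = {name: torch.round(p * 10) / 10 for name, p in original.named_parameters()}

# Only modifies the temporary dict returned by state_dict(); the model is unchanged.
quantized.state_dict().update(quant_dict)
# Copies the values of W into quantized's own parameters; autograd does not track the copy.
quantized.load_state_dict(quant_dict)

loss = quantized(torch.randn(3, 4)).sum()
loss.backward()

print(quantized.weight.grad is not None)  # True: gradient w.r.t. W
print(original.weight.grad)               # None: no gradient reaches w
```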

How can I get the gradients with respect to w instead?
Can anyone help?

Thanks in advance.