How to differentiably alter weights in a torch.nn.Module

Hey, I have a question that might be a bit weird.

Given a model, either initialized with some weights or pretrained, I want to do the following:

  1. access a weight of the torch.nn.Module model m, call it w0
  2. compute w1 = f(w0, u), i.e. modify the weight based on an auxiliary variable u
  3. compute the gradient w.r.t. u, namely dL(m(w1))/du (see the sketch right below)
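
Roughly what I mean, as a minimal sketch (the toy nn.Linear and the additive f are just placeholders, not my real model or modification):

```python
import torch
import torch.nn as nn

# toy stand-ins -- the real model / f don't matter for the question
m = nn.Linear(3, 3)                          # could also be a pretrained model
u = torch.randn(3, 3, requires_grad=True)    # auxiliary variable

# 1. access a weight w0 of the module
w0 = m.weight

# 2. w1 = f(w0, u), here just an additive shift as a placeholder
w1 = w0 + u

# 3. I now want to run the forward pass *with w1 in place of w0*,
#    compute a loss L, and get dL(m(w1))/du via backward().
#    This is the step I can't figure out.
```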

Neither

  • model.state_dict()
    nor
  • model.parameters()

works: state_dict doesn't let the gradient pass through (loading the modified values back just copies them in), and you can't reassign parameters via model.parameters() either, because assigning to what the iterator yields doesn't change the parameters the model actually uses.
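
Roughly the two attempts, again as a minimal sketch with the same toy m and u (just to illustrate the failure modes I'm describing):

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 3)
u = torch.randn(3, 3, requires_grad=True)

# attempt 1: via the state dict -- load_state_dict copies the values
# into the existing parameters under no_grad, so the graph back to u
# is cut and u never receives a gradient
sd = m.state_dict()
sd["weight"] = sd["weight"] + u
m.load_state_dict(sd)

# attempt 2: via parameters() -- assigning to the loop variable only
# rebinds the local Python name; the module keeps its original parameter
for name, p in m.named_parameters():
    if name == "weight":
        p = p + u    # m.weight is unchanged, nothing flows back to u
```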

Any help would be appreciated, thanks in advance!