Intermediate params update; set as a leaf variable

Hello,

I wish to create intermediate trainable parameters in my network. I am able to collect intermediate gradients with the retain_grad() method and perform the updates manually, but I’d prefer to use the optim module PyTorch provides so that I don’t have to write my own functions for the various gradient update rules out there.

I’m relatively new to PyTorch, and my understanding is that the updates on the intermediate params are not working because they are non-leaf nodes. Is there a way to detach the previous computations to turn them into leaf nodes, so that optimizer.step() can update them? Or is there a way to pass these intermediate gradients to the optimizer object itself?

The code below gives a vague illustration of what I am trying to accomplish. Any help is greatly appreciated!

import torch
import torch.optim as optim

network = MyNetwork() # MyNetwork and MyLoss are the poster's own code (not shown)

fc_weights = torch.tensor([[1,0],[0,1]],requires_grad=False,dtype=torch.float)
input = torch.tensor([[2],[3]],requires_grad=False,dtype=torch.float) # sample output of several layers that are frozen

params = torch.tensor([[0.25],[0.125]],requires_grad=True,dtype=torch.float) # learnable parameters for first final layer
params2 = torch.rand(2,1,requires_grad=True) # learnable parameters for second final layer
params3 = torch.rand(2,1,requires_grad=True) # learnable parameters for third final (final) layer

learning_rate = 1000 # just some sample learning rate to probe changes
optimizer = optim.SGD([params], lr=learning_rate)
optimizer2 = optim.SGD([params2],lr=learning_rate)
optimizer3 = optim.SGD([params3],lr = learning_rate)
optimizer.zero_grad() # zero gradients initially to be safe
optimizer2.zero_grad()
optimizer3.zero_grad()

print(params) # check value before nudge
print(params2) # check value before nudge
print(params3) # check value before nudge
print()

x = torch.hstack((input,params))
params2 = network(x,fc_weights)
params2.retain_grad() # keep intermediate grads
print(params2.is_leaf) # prints False
params3 = network(params2,fc_weights)
params3.retain_grad() # keep intermediate grads
print(params3.is_leaf) # prints False
print()

loss = MyLoss(params3)
loss.backward()

print(params.grad) # intermediate gradients successfully computed
print(params2.grad) # intermediate gradients successfully computed
print(params3.grad) # intermediate gradients successfully computed
print()

optimizer.step() # gradient update/ nudge
optimizer2.step() # gradient update/ nudge
optimizer3.step() # gradient update/ nudge

print(params) # this learnable parameter is updated successfully because it is a leaf node.
print(params2) # this learnable parameter is not updated, which is what I am trying to figure out
print(params3) # this learnable parameter is not updated, which is what I am trying to figure out

You are re-assigning params2 and params3 in the lines params2 = network(x,fc_weights) and params3 = network(params2,fc_weights).
After those lines, they no longer hold the tensors you defined initially.

@InnovArul
Appreciate the response. What’s wrong with assigning new values to them? Isn’t it stored as the same object internally? The reason I do it that way is that I want params, params2, and params3 to undergo a nested set of computations; each should pick up the computations it undergoes relative to where it is placed.

Initially, params2 holds a reference to the tensor created by torch.rand(2,1,requires_grad=True). When you assign to it again, it holds a reference to an intermediate value computed by the model. So the answer is no: params2 is not holding the same object.
Also, the way you use params2 tells me that it is just an intermediate variable and not something you want to optimize.
Could you elaborate a bit more on your end goal, with some context?

Like you said, they “will” pick up the computations and pass the gradients back to the earlier computations. That is definitely handled by PyTorch. But you cannot optimize params2, because it is an intermediate value calculated by the model (i.e., it changes for different inputs).
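The detach idea from the original question can be checked with a small sketch. It uses only plain tensors, and the names a, b and b_leaf are illustrative. Backpropagating through a non-leaf tensor reaches the leaf it was computed from, while detach().requires_grad_() does produce a leaf an optimizer could update, but it also cuts the graph, so no gradient flows back past it.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)  # leaf created by the user
b = a * 3                                          # non-leaf: output of an operation

# Backprop through the non-leaf b reaches the leaf a (d/da of sum((3a)^2) is 18a)
(b ** 2).sum().backward()
grad_through_b = a.grad.clone()

# detach() makes a graph-free copy; requires_grad_() turns it into a trainable leaf
b_leaf = b.detach().requires_grad_()
a.grad = None  # clear the earlier gradient for a clean comparison
(b_leaf ** 2).sum().backward()

print(b.is_leaf)       # False
print(b_leaf.is_leaf)  # True
print(a.grad)          # None: the gradient stops at b_leaf
```

So detaching would make params2 optimizable on its own, but the loss would no longer send any gradient back to params through it.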

With the current code example and definitions of params*, it is not clear what you expect PyTorch to do (at least not to me).

@InnovArul

I actually do intend for params, params2 and params3 to be optimized, not just stored as intermediate placeholder variables. They all feed into an unsupervised loss function, MyLoss. The gradients obtained via retain_grad() are the partial derivatives of MyLoss with respect to params, params2 and params3. Taking SGD as an example, the update for params2 would be:
params2 -= lr * (dMyLoss / dparams2)

where params2.data starts from the value it had at instantiation (before being overwritten by the intermediate pass). Ditto for the other params.
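For a leaf tensor, the update written above is exactly what plain optim.SGD (no momentum, no weight decay) performs. A minimal sketch, assuming a hypothetical my_loss as a stand-in for MyLoss, compares the manual update with optimizer.step():

```python
import torch
import torch.optim as optim

lr = 0.1
p_manual = torch.tensor([[0.25], [0.125]], requires_grad=True)
p_optim = p_manual.detach().clone().requires_grad_()  # independent copy, same values


def my_loss(p):
    # Hypothetical stand-in for MyLoss
    return (p ** 2).sum()


# Manual update: p -= lr * dLoss/dp, done outside autograd recording
my_loss(p_manual).backward()
with torch.no_grad():
    p_manual -= lr * p_manual.grad

# The same update through optim.SGD
opt = optim.SGD([p_optim], lr=lr)
opt.zero_grad()
my_loss(p_optim).backward()
opt.step()

print(torch.allclose(p_manual, p_optim))  # True
```

Both routes only work because p_manual and p_optim are leaves that actually take part in the loss computation.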

But again, I wish to abstract this away by using an optim class instance. I think I see what you mean: reassigning creates a different object, so the optimizer loses track of params2, for example, after the reassignment (please correct me if I’m wrong).
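That reading can be checked directly. The sketch below mimics the original code with a hypothetical p1 * 2 in place of network(x,fc_weights): the optimizer keeps a reference to the original tensor, the name params2 is rebound to a new non-leaf tensor, and because the original tensor never takes part in the loss, its .grad stays None and SGD skips it.

```python
import torch
import torch.optim as optim

params2 = torch.rand(2, 1, requires_grad=True)
optimizer2 = optim.SGD([params2], lr=1000.0)
tracked = optimizer2.param_groups[0]["params"][0]  # the tensor the optimizer holds
before = tracked.detach().clone()

p1 = torch.tensor([[0.25], [0.125]], requires_grad=True)
params2 = p1 * 2        # hypothetical stand-in for network(x, fc_weights): rebinds the name
params2.retain_grad()
params2.sum().backward()

print(tracked is params2)         # False: the optimizer still holds the old tensor
print(params2.grad is not None)   # True: the new non-leaf tensor got a gradient
print(tracked.grad)               # None: the original tensor was never used
optimizer2.step()                 # SGD skips parameters whose .grad is None
print(torch.equal(tracked, before))  # True: unchanged
```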

The reason I reassign params* is that I want to instantiate them outside my training loop first and pass them into an instance of a custom model I’ve built. Loosely speaking, what I’m trying to do is train a standard neural network, freeze it (i.e., its trained fully connected weights), and then train another model that takes these frozen weights and trains another set of parameters. These new parameters (the params*) act as activations, similar to those in a fully connected net, but they are also trainable (they are adjusted to minimize MyLoss).

Please let me know if this is enough, and whether you can infer a solution to my problem. If not, I can provide additional details. I really appreciate your help!

Inside your custom model's __init__, you can define these params* as nn.Parameter instances and use them in the computation (the new model's forward) along with the frozen weights from your earlier model.
Then pass new_model.parameters() to the optimizer. Will this work for you?

But remember that you can’t reassign params* in any case; use them in the computation as nn.Parameter objects. The optimizer will then take care of updating params* with their gradients automatically.
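A minimal sketch of this suggestion follows. The class name StackedParamModel is hypothetical, and so is the way each trainable tensor enters the forward pass: it is simply added to the input of a frozen layer here, which should be replaced by your own network logic. The frozen weights are registered as a buffer so they are stored with the model but not returned by .parameters(). The loss is a stand-in for MyLoss.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class StackedParamModel(nn.Module):
    """Hypothetical model: frozen weights plus trainable per-layer tensors."""

    def __init__(self, fc_weights: torch.Tensor):
        super().__init__()
        # Frozen weights from the earlier, already-trained network (not trainable)
        self.register_buffer("fc_weights", fc_weights.clone())
        # Trainable tensors, registered once and never reassigned
        self.params = nn.Parameter(torch.tensor([[0.25], [0.125]]))
        self.params2 = nn.Parameter(torch.rand(2, 1))
        self.params3 = nn.Parameter(torch.rand(2, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: each trainable tensor is added before a frozen layer
        h = self.fc_weights @ (x + self.params)
        h = self.fc_weights @ (h + self.params2)
        return self.fc_weights @ (h + self.params3)


fc_weights = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
inp = torch.tensor([[2.0], [3.0]])  # sample output of the frozen layers

model = StackedParamModel(fc_weights)
optimizer = optim.SGD(model.parameters(), lr=0.01)  # sees params, params2, params3

before = [p.detach().clone() for p in model.parameters()]
optimizer.zero_grad()
loss = (model(inp) ** 2).sum()  # stand-in for MyLoss
loss.backward()
optimizer.step()  # updates all three nn.Parameter tensors in place
```

Because the attributes are updated in place by the optimizer rather than reassigned, the same three tensors carry their learned values from one training iteration to the next.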