Question regarding parameters for optimizer

Hello,
In one of my deep learning applications, I modify the gradients after calling loss.backward(), so they no longer match the originally computed values. When I then call optimizer.step(), does the optimizer use the modified gradients? To demonstrate, here is a snippet of my code:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

And suppose I have this function:

def change_gradients(net):
    for param in net.parameters():
        # some operation on param.grad
        ...

Now inside the training loop:

        model.zero_grad()
        loss.backward()
        change_gradients(model)  # the model must be passed in
        optimizer.step()
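A self-contained sketch of this loop follows. The nn.Linear model, the random data and the MSE loss are assumptions (the post does not show them), and the body of change_gradients is hypothetical: it clamps each gradient element to [-0.1, 0.1] in place, standing in for "some operation".

```python
import torch

torch.manual_seed(0)

# Toy setup (assumption): the original post does not show the model or data.
model = torch.nn.Linear(2, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
loss_fn = torch.nn.MSELoss()
inputs = torch.randn(8, 2)
targets = torch.randn(8, 3)


def change_gradients(net):
    # Hypothetical operation: clamp every gradient element to [-0.1, 0.1], in place.
    for param in net.parameters():
        if param.grad is not None:
            param.grad.clamp_(-0.1, 0.1)


for _ in range(3):
    model.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    change_gradients(model)  # edits param.grad after backward()
    optimizer.step()  # uses whatever param.grad holds at this point
```

Clamping was chosen over a uniform rescale such as param.grad.mul_(0.5) because Adam divides by a running estimate of the gradient magnitude, which largely cancels a constant scale factor and would make the demo uninformative.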

In that case, does the optimizer use the modified gradients or the original ones when it applies the update to the weights?

If not, how can I update the optimizer's parameters? For example, what if I redefine the optimizer after change_gradients(), passing it the model returned by that function:

def change_gradients(net):
    for param in net.parameters():
        # some operation on param.grad
        ...
    return net  # was "return model", which is not defined in this function

Now inside the training loop:

        model.zero_grad()
        loss.backward()
        model = change_gradients(model)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
        optimizer.step()

Would that work? And how would the optimizer know the previous values of m_t and v_t if it is redefined every iteration, discarding all of its state (apart from receiving the model again)?
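A minimal sketch of the concern about m_t and v_t, assuming a toy nn.Linear model and random input: after one step, Adam stores per-parameter state, and a freshly constructed optimizer starts with none of it.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

# One ordinary step so that Adam creates its per-parameter state.
loss = model(torch.randn(4, 2)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The original optimizer keeps m_t ('exp_avg'), v_t ('exp_avg_sq') and, with
# amsgrad=True, 'max_exp_avg_sq' for each parameter.
print(sorted(optimizer.state[model.weight].keys()))

# A re-created optimizer starts from empty state: m_t and v_t are gone.
new_optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
print(len(optimizer.state), len(new_optimizer.state))  # 2 0
```

Keeping a single optimizer for the whole run preserves this state. If rebuilding were ever unavoidable, optimizer.state_dict() and load_state_dict() can carry the state over to the new instance.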

Thanks!

This is related to this thread.

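The optimizer holds references to the model's parameters (not copies), and optimizer.step() reads each parameter's .grad at the moment it is called. So any change you make to the gradients before step() is taken into account, and there is no need to re-create the optimizer. The following snippet shows that the optimizer's parameters share memory with the model's parameters: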

import torch
model = torch.nn.Linear(2, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)
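# Memory addresses of the model's parameter tensors
x = set(p.data_ptr() for p in model.parameters())
# Memory addresses of the tensors the optimizer will update
y = set(p.data_ptr() for p in optimizer.param_groups[0]['params'])
print(x == y) # prints True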
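As a further check (a sketch assuming the same toy nn.Linear model and random input), zeroing every gradient in place between backward() and step() leaves the weights exactly unchanged, because Adam's first update is then zero. That is only possible if step() reads the modified .grad rather than the values backward() originally produced.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

loss = model(torch.randn(4, 2)).pow(2).mean()
optimizer.zero_grad()
loss.backward()

# Modify the gradients in place after backward(): here, zero them out.
for p in model.parameters():
    p.grad.zero_()

before = [p.detach().clone() for p in model.parameters()]
optimizer.step()
after = [p.detach().clone() for p in model.parameters()]

# With all-zero gradients Adam's first step is exactly zero,
# so the weights are unchanged.
print(all(torch.equal(b, a) for b, a in zip(before, after)))  # True
```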