How to exclude a model parameter from Model.parameters() without setting its requires_grad to False

I want to do this because I need two optimizers to update different subsets of the model's parameters.

For example, my model has many layers plus an architecture parameter, model.arch:

import torch
import torch.nn as nn

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    ...
    self.linear_layers = nn.ModuleList()
    for i in range(4):
      self.linear_layers.append(nn.Linear(...))
    # many other layers
    # nn.Parameter registers arch, so it shows up in self.parameters()
    self.arch = nn.Parameter(torch.zeros([...]))

  def get_arch_parameters(self):
    return [self.arch]

  def get_other_parameters(self):
    # I want to return all parameters other than self.arch
    ...
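
Whether arch shows up in model.parameters() at all depends on how it is assigned: a plain tensor created with torch.zeros(..., requires_grad=True) and stored as an attribute is not registered, while an nn.Parameter is. A minimal check, using made-up layer and tensor sizes:

```python
import torch
import torch.nn as nn

class PlainTensorArch(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Plain tensor attribute: NOT registered as a parameter
        self.arch = torch.zeros(3, requires_grad=True)

class ParameterArch(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # nn.Parameter assigned to a Module attribute IS registered
        self.arch = nn.Parameter(torch.zeros(3))

plain_names = [name for name, _ in PlainTensorArch().named_parameters()]
param_names = [name for name, _ in ParameterArch().named_parameters()]
print("arch" in plain_names)  # False
print("arch" in param_names)  # True
```

An unregistered tensor is also skipped by model.to(device) and by state_dict(), which is usually not what you want for a trainable architecture parameter.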

Now I want to use two optimizers to update different parameters:

opt1 = torch.optim.SGD(model.get_other_parameters(), ...)
opt2 = torch.optim.SGD(model.get_arch_parameters(), ...)
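
A sketch of how two optimizers over disjoint parameter sets typically interact in one training step: a single backward() fills .grad for every parameter, and each step() only touches the tensors that optimizer was given. The model below is a hypothetical stand-in (four Linear(8, 8) layers, an arch of shape [4] initialised to ones, and a made-up forward pass); the learning rates are arbitrary.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the model in the question
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
        # Hypothetical init: ones, so every layer contributes to the output
        self.arch = nn.Parameter(torch.ones(4))

    def forward(self, x):
        # Hypothetical use of arch: one scalar weight per layer
        for weight, layer in zip(self.arch, self.linear_layers):
            x = x + weight * layer(x)
        return x

torch.manual_seed(0)
model = Model()
opt1 = torch.optim.SGD(model.linear_layers.parameters(), lr=0.1)
opt2 = torch.optim.SGD([model.arch], lr=0.1)

arch_before = model.arch.detach().clone()
loss = model(torch.randn(16, 8)).pow(2).mean()

opt1.zero_grad()
opt2.zero_grad()
loss.backward()   # computes gradients for all parameters
opt1.step()       # updates only the linear layers
opt2.step()       # updates only arch
print(torch.equal(arch_before, model.arch.detach()))  # False
```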

However, I don't know how to implement model.get_other_parameters(), which should return all model parameters except model.arch.
I can't simply set model.arch.requires_grad=False and use model.parameters() instead, because opt2 needs to update it.
So I'd like to know whether there is a way to exclude a parameter from model.parameters() without setting its requires_grad to False.
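
One way to do this without touching requires_grad is to filter model.parameters() by object identity. In the sketch below, the helper name parameters_except and the toy model are hypothetical. It compares id() values rather than using == or `p in some_list`: == on tensors is elementwise, so a membership test that falls back to it can raise or give a wrong answer.

```python
import torch
import torch.nn as nn

def parameters_except(module, excluded):
    """Hypothetical helper: yield module's parameters, skipping any in `excluded`.

    Membership is checked by object identity, not by tensor equality."""
    excluded_ids = {id(p) for p in excluded}
    for p in module.parameters():
        if id(p) not in excluded_ids:
            yield p

model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 1))
model.arch = nn.Parameter(torch.zeros(2))  # hypothetical arch parameter

others = list(parameters_except(model, [model.arch]))
print(len(others))               # 4 (two weights, two biases)
print(model.arch.requires_grad)  # True, arch is still trainable
```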

Thanks in advance.

Can you simply return the ModuleList, i.e. return self.linear_layers? If I understood correctly, these layers don't include the arch variable, so it should work.

Hi Andrea,

Thanks. You inspired me.

Just returning a list of layers won't work, since the params argument of an optimizer must be an iterable of tensors, not modules. But you reminded me that each layer is itself a Module with its own parameters() method, and I found that a ModuleList has a parameters() method as well.
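
A small check of both points, with made-up layer sizes: passing the ModuleList itself to an optimizer raises a TypeError because its items are modules, while its parameters() method yields the tensors the optimizer expects.

```python
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

# The optimizer iterates the ModuleList and finds Linear modules, not tensors
try:
    torch.optim.SGD(layers, lr=0.1)
    failed = False
except TypeError:
    failed = True
print(failed)  # True

# parameters() yields the weight and bias tensors of every contained layer
opt = torch.optim.SGD(layers.parameters(), lr=0.1)
n_params = len(opt.param_groups[0]["params"])
print(n_params)  # 4
```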

So I just wrote:

import itertools

def get_other_parameters(self):
  # itertools.chain joins several parameter iterators into a single iterator
  return itertools.chain(self.linear_layers.parameters(), self.other.parameters(), ...)

It seems to work.
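
A runnable version of that approach, with a hypothetical head submodule standing in for the "other" layers. One thing to keep in mind: itertools.chain returns a one-shot iterator, so call get_other_parameters() again for each consumer rather than reusing the returned object.

```python
import itertools
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        self.head = nn.Linear(4, 1)  # hypothetical "other" submodule
        self.arch = nn.Parameter(torch.zeros(2))

    def get_other_parameters(self):
        # Chain the parameter iterators of every submodule except arch
        return itertools.chain(self.linear_layers.parameters(),
                               self.head.parameters())

model = Model()
others = model.get_other_parameters()
first = len(list(others))   # consumes the iterator
second = len(list(others))  # already exhausted
print(first, second)  # 6 0
```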

However, it's still a little clumsy, since I have to enumerate every submodule whose parameters I want to optimize. When the model has many parameters and I only want to exclude a few of them, this is neither efficient nor elegant.

Hence I still want to know if there is a better solution.
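
One approach that avoids enumerating submodules is to filter named_parameters() by name, so any newly added layer is picked up automatically. The sketch assumes arch is a top-level nn.Parameter, whose name is then "arch" (a nested one would be e.g. "block.arch"); the layers and learning rates are made up. If both groups can use the same optimizer class, a single optimizer with two parameter groups, each with its own lr, is an alternative to two optimizers.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
        self.head = nn.Linear(4, 1)  # hypothetical extra layer
        self.arch = nn.Parameter(torch.zeros(3))

    def get_arch_parameters(self):
        return [self.arch]

    def get_other_parameters(self):
        # Every registered parameter except the one named "arch"
        return [p for name, p in self.named_parameters() if name != "arch"]

model = Model()

# Option 1: two optimizers, as in the question
opt1 = torch.optim.SGD(model.get_other_parameters(), lr=0.1)
opt2 = torch.optim.SGD(model.get_arch_parameters(), lr=0.01)

# Option 2: one optimizer with per-group options; the first group
# falls back to the default lr=0.1
opt = torch.optim.SGD([
    {"params": model.get_other_parameters()},
    {"params": model.get_arch_parameters(), "lr": 0.01},
], lr=0.1)
```

Option 2 only works when both parameter sets use the same optimizer class; if arch needs, say, Adam while the rest uses SGD, two optimizers are still required.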