Correct way to freeze a part of a model

I have an outer model that contains an inner backbone. Normally, they both train. Sometimes, I want to freeze the backbone. As far as I understand, this means:

  1. Once, at the beginning: iterate over all of the backbone's parameters and set their requires_grad to False.
  2. Make sure that the backbone is always in .eval() mode and never in .train() mode, so that it does not apply dropout etc.

The second part is what I’m concerned with. What’s the correct way to put the inner model permanently into .eval() mode, such that calling .train() on the outer model does not undo it? e.g.
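import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()


outer = Outer()
outer.inner.eval()
print(outer.inner.training)  # is False as desired
outer.train()  # recursively calls .train() on every child module
print(outer.inner.training)  # is True, not desired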
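Step 1 from the list above (freezing the parameters) can be sketched as follows. This is a minimal sketch: `Backbone`, `Model`, the `head` layer and the layer sizes are hypothetical stand-ins for the inner and outer models. `nn.Module.requires_grad_(False)` flips `requires_grad` on every parameter of the backbone, and only the still-trainable parameters are handed to the optimizer.

```python
import torch
from torch import nn


class Backbone(nn.Module):  # hypothetical inner model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


class Model(nn.Module):  # hypothetical outer model with its own trainable head
    def __init__(self):
        super().__init__()
        self.inner = Backbone()
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.inner(x))


model = Model()
# Step 1: freeze the backbone's parameters once, up front.
model.inner.requires_grad_(False)

# Only hand the still-trainable parameters (the head) to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
print(model.inner.fc.weight.grad)          # None: no gradient reaches the frozen backbone
print(model.head.weight.grad is not None)  # True: the head still trains
```

This only covers step 1; the backbone's train/eval mode is still reset by `model.train()`, which is exactly the problem described in the question.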

I could override the .train() function of the inner model and make it a no-op, but that seems hacky. What’s the best practice?
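For reference, the override mentioned here could look like the sketch below. `FrozenInner` is a hypothetical subclass, and the `Dropout` layer is added only so the mode has a visible effect. Because `nn.Module.train(mode)` recurses by calling `.train(mode)` on each child, and `.eval()` is just `.train(False)`, forcing `mode=False` in the override keeps the whole backbone in eval mode whenever a parent calls `.train()`.

```python
import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        return self.drop(self.fc(x))


class FrozenInner(Inner):
    """Hypothetical variant of Inner that always stays in eval mode."""

    def __init__(self):
        super().__init__()
        self.requires_grad_(False)  # step 1: freeze parameters
        self.eval()                 # step 2: start in eval mode

    def train(self, mode: bool = True):
        # Ignore the requested mode and propagate eval mode to all submodules.
        return super().train(False)


class Outer(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x)


outer = Outer(FrozenInner())
outer.train()
print(outer.training)             # True
print(outer.inner.training)       # False
print(outer.inner.drop.training)  # False
```

The upside of this design is that "frozen" becomes a property of the module itself, so the outer model needs no special handling; the downside is that `.train()` silently ignores its argument.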

If you try outer.inner.eval() after outer.train(), does that work?

That works, but what I want is the cleanest way to mark a module as frozen and not have to think about it again (having to remember to call outer.inner.eval() after each call to outer.train() is the opposite of that). I could also override outer.train() to always call outer.inner.eval() after it finishes running, but that also doesn’t seem very clean.
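Overriding `outer.train()` is a small change, because `nn.Module.eval()` is implemented as `self.train(False)`, so one override covers both calls. A minimal sketch, assuming a hypothetical `head` layer on the outer model and a `Dropout` in the backbone:

```python
import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        return self.drop(self.fc(x))


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
        self.head = nn.Linear(8, 2)

    def train(self, mode: bool = True):
        # Let nn.Module set the mode on self and all children first...
        super().train(mode)
        # ...then force the backbone back into eval mode.
        self.inner.eval()
        return self

    def forward(self, x):
        return self.head(self.inner(x))


outer = Outer()
outer.inner.requires_grad_(False)  # step 1: freeze the backbone's parameters
outer.train()
print(outer.head.training, outer.inner.training)  # True False
outer.eval()
print(outer.head.training, outer.inner.training)  # False False
```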

The only option I can think of that would be ‘cleaner’ is to define your own member function(s) on the outer class and, inside them, call .eval() and .train() on inner and outer as needed. Because you are doing something somewhat atypical, I don’t think there is another way to do it using only existing PyTorch functions.
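One way such member functions might look, as a sketch rather than an established pattern: a hypothetical `freeze_inner()` method records a flag, toggles `requires_grad` on the backbone, and an overridden `train()` consults the flag. `inner_frozen` and `freeze_inner` are illustrative names, not PyTorch attributes, and the `nn.Sequential` backbone is a stand-in.

```python
import torch
from torch import nn


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))  # stand-in backbone
        self.head = nn.Linear(8, 2)
        self.inner_frozen = False  # hypothetical flag

    def freeze_inner(self, frozen: bool = True):
        """Hypothetical helper: freeze or unfreeze the backbone in one call."""
        self.inner_frozen = frozen
        self.inner.requires_grad_(not frozen)
        # Re-apply the current mode so the backbone's mode matches the flag.
        return self.train(self.training)

    def train(self, mode: bool = True):
        super().train(mode)
        if self.inner_frozen:
            self.inner.eval()
        return self

    def forward(self, x):
        return self.head(self.inner(x))


model = Outer()
model.freeze_inner()
model.train()
print(model.inner.training)  # False
model.freeze_inner(False)
print(model.inner.training)  # True (the model is still in train mode)
```

Callers only ever use `model.train()`, `model.eval()` and `model(x)`; whether the backbone is frozen is decided once via `freeze_inner()`.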

tensor.detach() might be closer to what you’re looking for.
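One possible reading of this suggestion, sketched under assumptions: the outer model detaches the backbone's output inside its own `forward`, so autograd never propagates into the backbone. The `detach_inner` switch and layer sizes are hypothetical. Note that `detach()` only stops gradients; it does not change dropout or batch-norm behaviour, so the train/eval question from the original post is not addressed by it.

```python
import torch
from torch import nn


class Outer(nn.Module):
    def __init__(self, detach_inner: bool = True):
        super().__init__()
        self.inner = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # stand-in backbone
        self.head = nn.Linear(8, 2)
        self.detach_inner = detach_inner  # hypothetical switch

    def forward(self, x):
        features = self.inner(x)
        if self.detach_inner:
            # Cut the graph here: no gradient flows back into self.inner.
            features = features.detach()
        return self.head(features)


model = Outer(detach_inner=True)
model(torch.randn(4, 8)).sum().backward()
print(model.inner[0].weight.grad)          # None
print(model.head.weight.grad is not None)  # True
```

Because the detach happens inside `forward`, callers use the model the same way either way; the backbone's parameters keep `requires_grad=True` but simply receive no gradient.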

Thanks, but that removes some of the abstraction: the user can no longer use the outer model without knowing about its parts. Ideally, everything would look the same whether inner is 1. a trainable model, 2. a frozen model, or 3. a function, but in your suggestion, as far as I understand, case 2 requires different usage.

I’m familiar with tensor.detach() in other cases, could you please elaborate how using it here helps?