Making a module always non-training?

I’m making a special EMA module which should always be in no-grad eval() mode. However, it’s a sub-module of a trainable module, so when its parent has train() called, it gets its train() enabled. Is there a way to make sure it never has gradients and is always in eval() mode? All its params should have requires_grad false and all its inferences should be done as if it’s in eval mode.

Hi, @Sam_Lerman!

One good way to achieve this result is applying these restrictions directly during the forward pass (i.e. in the forward() method).

One thing you have to decide, based on your system and needs, is whether these restrictions belong to the submodule itself or to the parent module.

It’s mostly a design choice: the functionality is the same in both cases, and the differences are shown below.

As you didn’t provide any code, I’ll use some dummy code examples.
They are untested sketches, so apologies for any mistakes.

Hope that it helps you! :blush:

Restrictions from the submodule

Properties:

  • The submodule always runs in eval mode and guarantees that no gradients are computed.
  • This means that there is no way to wrongly execute the submodule in train mode and/or with grads.
  • It obviously also means that you cannot do it even if you want to, so if you need to reuse this same submodule in another network without these restrictions, you should choose the next method.

Here is a code example:

import torch
from torch import nn

# DummyLayer is a placeholder for any nn.Module (e.g. nn.Linear)
class MySubmodule(nn.Module):
    def __init__(self):
        super().__init__()

        # instantiate some layers
        self.dummy_layer1 = DummyLayer()
        self.dummy_layer2 = DummyLayer()

    def forward(self, x):
        # run forward pass, applying the fixed
        # restrictions for the submodule
        with torch.no_grad():
            self.eval()
            x = self.dummy_layer1(x)
            x = self.dummy_layer2(x)
        return x

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        # instantiate some layers, including the submodule
        self.dummy_layer1 = DummyLayer()
        self.dummy_layer2 = DummyLayer()

        self.my_submodule = MySubmodule()

        self.dummy_layer3 = DummyLayer()
        self.dummy_layer4 = DummyLayer()

    def forward(self, x):
        # run forward pass
        x = self.dummy_layer1(x)
        x = self.dummy_layer2(x)

        # use the submodule as any other layer
        x = self.my_submodule(x)

        x = self.dummy_layer3(x)
        x = self.dummy_layer4(x)
        return x
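
To sanity-check the example above, here is a minimal sketch that swaps the placeholder DummyLayer for nn.Linear and nn.Dropout (that swap is my assumption, and the names FrozenSubmodule and Parent are hypothetical). It also shows one side effect worth knowing about: because the submodule runs under torch.no_grad(), gradients stop flowing to any parent layers that come before it.

```python
import torch
from torch import nn


class FrozenSubmodule(nn.Module):
    """Hypothetical stand-in for MySubmodule: eval mode, no gradients."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.5)  # behaves differently in train vs. eval

    def forward(self, x):
        with torch.no_grad():
            self.eval()
            return self.dropout(self.linear(x))


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre = nn.Linear(4, 4)
        self.frozen = FrozenSubmodule()
        self.post = nn.Linear(4, 2)

    def forward(self, x):
        return self.post(self.frozen(self.pre(x)))


model = Parent()
model.train()  # this also sets model.frozen.training to True
flag_before_forward = model.frozen.training

x = torch.randn(3, 4)
frozen_out = model.frozen(x)  # forward() switches the submodule back to eval
flag_after_forward = model.frozen.training

# backward only reaches layers placed after the frozen submodule
model(x).sum().backward()
pre_grad = model.pre.weight.grad    # stays None: the graph is cut by no_grad
post_grad = model.post.weight.grad  # populated as usual
```

So between parent.train() and the next forward call the submodule's training flag is briefly True, and layers before the submodule receive no gradient through it. If the layers before it need to be trained through the submodule's output, this pattern is probably not what you want.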

Restrictions from the parent module

Properties:

  • The submodule does not enforce any restrictions.
  • This means that it’s the job of each network implementation to decide what will be the behaviour of the submodule.
  • It’s the best option if you need to reuse this same submodule in another network without these restrictions.

Here is a code example:

class MySubmodule(nn.Module):
    def __init__(self):
        super().__init__()

        # instantiate some layers
        self.dummy_layer1 = DummyLayer()
        self.dummy_layer2 = DummyLayer()

    def forward(self, x):
        # run forward pass
        x = self.dummy_layer1(x)
        x = self.dummy_layer2(x)
        return x

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        # instantiate some layers, including the submodule
        self.dummy_layer1 = DummyLayer()
        self.dummy_layer2 = DummyLayer()

        self.my_submodule = MySubmodule()

        self.dummy_layer3 = DummyLayer()
        self.dummy_layer4 = DummyLayer()

    def forward(self, x):
        # run forward pass
        x = self.dummy_layer1(x)
        x = self.dummy_layer2(x)

        # apply the restrictions for the submodule
        with torch.no_grad():
            self.my_submodule.eval()
            x = self.my_submodule(x)

        x = self.dummy_layer3(x)
        x = self.dummy_layer4(x)
        return x
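
One detail about both examples: torch.no_grad() only stops the graph from being recorded, it doesn't change the requires_grad flag on the parameters themselves. If you also want those flags to be False, as the question asks, one option is Module.requires_grad_(False). A minimal sketch, with hypothetical names (Parent, head, ema) and nn.Linear standing in for the real layers:

```python
import torch
from torch import nn


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)  # trainable layer (stand-in)
        self.ema = nn.Linear(4, 4)   # submodule to freeze (stand-in)
        # set requires_grad=False on every parameter of the submodule
        self.ema.requires_grad_(False)


parent = Parent()

# hand only the still-trainable parameters to the optimizer
trainable = [p for p in parent.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)
```

Filtering the parameter list is optional, since parameters without gradients are skipped by the optimizer anyway, but it makes the intent explicit.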

Can you not override the train method for the EMA module?
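
A minimal sketch of that idea, in case it helps. AlwaysEval is a hypothetical wrapper name, not a PyTorch class. It relies on the fact that parent.train() and eval() both end up calling train(mode) on each child, so overriding that one method is enough to pin the mode.

```python
import torch
from torch import nn


class AlwaysEval(nn.Module):
    """Hypothetical wrapper: keeps the wrapped module frozen and in eval mode."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.requires_grad_(False)  # freeze every wrapped parameter
        self.train()                # goes through the override below

    def train(self, mode=True):
        # ignore the requested mode and always stay in eval mode
        return super().train(False)

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


ema = AlwaysEval(nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5)))
parent = nn.Sequential(nn.Linear(4, 4), ema)

parent.train()  # parent switches to train mode, ema stays in eval mode
out = ema(torch.randn(2, 4))
```

Compared with calling self.eval() inside forward(), the training flag here is never True, not even between parent.train() and the next forward pass.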