Automatically check whether a module has different behavior for train() and eval() modes


Is there a simple attribute, method, or function that returns a bool indicating whether a module contains code whose behavior depends on whether it is in .train() or .eval() mode?

For instance, for batchnorm with running stats, there would be different behavior. In that case, I would like to have a function:

check_train_eval_difference(batchnorm()) -> True

For conv, which does not change:

check_train_eval_difference(conv()) -> False

Also, I forgot what it is typically called when a model behaves differently in the two modes. I read a specific term for this somewhere. Does anyone know?


Best, JZ

You can define your own function.

Maybe this can help you.

import torch
import torch.nn as nn

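def check_train_eval_difference(module, input):
    was_training = module.training  # remember the mode so it can be restored
    with torch.no_grad():
        module.train()
        a = module(input)  # note: this forward pass updates BatchNorm running stats
        module.eval()
        b = module(input)
    module.train(was_training)
    # Different outputs for the same input mean the module is mode-dependent
    return not torch.allclose(a, b)

batch_norm = nn.BatchNorm2d(3)
conv = nn.Conv2d(3, 2, 2)

input = torch.randn(20, 3, 35, 45)

print(check_train_eval_difference(batch_norm, input))
print(check_train_eval_difference(conv, input))

# Output:
# True
# False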

Hello Matias,

I was thinking more of a version of this function that does not require sending input through the layers, but can simply analyse a network in advance (before running it) to determine whether it contains layers with different train/eval behaviour. Ideally, such layers would have a fixed boolean attribute that indicates this. But if there is no such thing, I will go with the function you provided. Thanks!

Best, JZ


You can check the internal .training flag:

bn = nn.BatchNorm2d(3)
print(bn.training)
# True
bn.eval()
print(bn.training)
# False
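
Note that .training only reports which mode a module is currently in, not whether its behavior depends on that mode. As far as I know, PyTorch has no built-in attribute for the latter, so a check that runs without input would have to rely on a hand-maintained list of layer types. Here is a minimal sketch of that idea. The names has_mode_dependent_layers and MODE_DEPENDENT_TYPES are made up for illustration, and the code assumes the private base classes _BatchNorm and _DropoutNd keep their current locations:

```python
import torch.nn as nn

# Hand-maintained and incomplete list (an assumption, not a PyTorch API).
# These are private base classes, so their location may change between versions.
MODE_DEPENDENT_TYPES = (
    nn.modules.batchnorm._BatchNorm,  # BatchNorm1d/2d/3d, SyncBatchNorm
    nn.modules.dropout._DropoutNd,    # Dropout, Dropout2d, AlphaDropout, ...
)

def has_mode_dependent_layers(model: nn.Module) -> bool:
    """Return True if the model or any submodule is a known mode-dependent type."""
    # model.modules() yields the model itself and all nested submodules
    return any(isinstance(m, MODE_DEPENDENT_TYPES) for m in model.modules())

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
print(has_mode_dependent_layers(net))                 # True
print(has_mode_dependent_layers(nn.Conv2d(3, 2, 2)))  # False
```

This is only a type check, so it can be wrong in both directions: it flags a BatchNorm layer created with track_running_stats=False even though that layer uses batch statistics in both modes, and it misses custom modules that branch on self.training in their forward(). For those cases, comparing outputs in both modes as in the function above is the more reliable test.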