Implement a forward hook that behaves differently in train and eval mode


I am trying to record some statistics in a list for both the training and validation steps. However, doing something like the following will mix stats from training and validation into the same list.

stat_list = []
def hook(self, input, output):
    # compute stat and append it to stat_list


Since I know PyTorch has model.train() and model.eval(), I am wondering whether I could control the hook's behavior using some flag set by model.train() and model.eval(). In my use case, it would look like:

train_stat_list = []
val_stat_list = []
def hook(self, input, output):
    if train_flag:
        # compute stat, append to train_stat_list
    if val_flag:
        # compute stat, append to val_stat_list

Just figured this out myself :-D.

It turns out nn.Module has an attribute called training, and model.train() and model.eval() essentially set it to True/False (default True):

from torch import nn

model = nn.Module()
model.train()
print(model.training)  # True
model.eval()
print(model.training)  # False

Yes, the .training flag on the module is what should be used to know whether you are in training or eval mode.
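Putting the thread together, here is a minimal sketch of a hook that routes statistics by mode via the module's .training flag. The toy nn.Linear model and the mean-of-output statistic are illustrative assumptions, not part of the original question:

```python
import torch
from torch import nn

train_stat_list = []
val_stat_list = []

def hook(module, input, output):
    # Example statistic (assumption): mean of the layer's output.
    stat = output.detach().mean().item()
    # module.training is set by model.train() / model.eval().
    if module.training:
        train_stat_list.append(stat)
    else:
        val_stat_list.append(stat)

model = nn.Linear(4, 2)  # toy module (assumption)
model.register_forward_hook(hook)

model.train()
model(torch.randn(1, 4))  # stat recorded in train_stat_list

model.eval()
model(torch.randn(1, 4))  # stat recorded in val_stat_list
```

Because the hook receives the module itself as its first argument, it can read the flag directly and no external train_flag/val_flag variables are needed.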