Problem
I am trying to record some statistics in a list for both the training and validation steps. However, doing something like the following will mix statistics
from training and validation into the same list:
stat_list = []

def hook(self, input, output):
    # compute stat
    stat_list.append(stat)

model.<layer>.register_forward_hook(hook)
Since I know PyTorch has model.train() and model.eval(), I am wondering whether I could control the behavior of the hook using some flag set by model.train() and model.eval(). In my use case, it would look like this:
train_stat_list = []
val_stat_list = []

def hook(self, input, output):
    if train_flag:
        # compute stat
        train_stat_list.append(stat)
    if val_flag:
        # compute stat
        val_stat_list.append(stat)
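A minimal sketch of one way this might work, assuming the flag in question is the `training` attribute that every `nn.Module` carries (`model.train()` sets it to `True` and `model.eval()` sets it to `False`, recursively for all submodules). The toy model, the choice of layer, and the statistic (mean of the layer output) are placeholders for illustration only, not part of the original setup.

```python
import torch
import torch.nn as nn

train_stat_list = []
val_stat_list = []

def hook(module, inputs, output):
    # Placeholder statistic (an assumption): mean of the layer output.
    stat = output.detach().mean().item()
    # module.training is True after model.train(), False after model.eval().
    if module.training:
        train_stat_list.append(stat)
    else:
        val_stat_list.append(stat)

# Hypothetical toy model; the hook is attached to its first layer.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handle = model[0].register_forward_hook(hook)

x = torch.randn(3, 4)

model.train()
model(x)              # recorded in train_stat_list

model.eval()
with torch.no_grad():
    model(x)          # recorded in val_stat_list
    model(x)          # recorded in val_stat_list

handle.remove()       # detach the hook once the statistics are collected
```

Because the hook reads the mode from the module it is attached to, no extra global flag has to be kept in sync with the training loop; the routing follows whatever `model.train()` or `model.eval()` was called last.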