Invoke forward hook function only once

Problem

I am trying to compute some statistics on an intermediate representation of the input data. It is natural to use .register_forward_hook to accomplish this goal. However, since computing these statistics has quite a long runtime, I can only afford to compute them once (in my case, on just the first mini-batch of validation data).

The skeleton of hook function looks like

stat_list = []
def hook(self, input, output):
    # Only collect statistics while the model is in eval mode.
    if not model.training:
        # compute stat ...
        stat_list.append(stat)  # was statistics_list — the name didn't match the list defined above

However, I am not sure how to do this correctly. I have tried two things

  • Access i in the validation loop (for i, (X_val, y_val) in enumerate(val_dataloader))
stat_list = []
def hook(self, input, output):
  if (model.training == False) and (i == 0):
    # do something
  • Set a global flag and modify it within the loop.
flag = True
stat_list = []
def hook(self, input, output):
  if (model.training == False) and (flag == True):
    flag = False
    # do something

But neither of them works.

Could someone help me?

For ease of discussion, I include the following runnable code.

import numpy as np
import torch

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.models import AlexNet
from torchvision.datasets import FakeData

from sklearn.metrics import accuracy_score

NUM_CLASSES = 2

LR = 1e-4
BATCH_SIZE = 32
MAX_EPOCH = 2

# is_available is a function: without the () the bare function object is
# always truthy, so the CPU fallback branch could never be taken.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transform = T.ToTensor()
train_dataset = FakeData(size=800, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
val_dataset = FakeData(size=200, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = AlexNet(num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

def hook(self, input, output):
    """Forward hook on model.features; announces every eval-mode forward pass."""
    if not model.training:
        print("validation...")

model.features.register_forward_hook(hook)

for epoch in range(MAX_EPOCH):
    model.train()
    print("epoch %d / %d" % ((epoch + 1), MAX_EPOCH))
    for i, (X_train, y_train) in enumerate(train_dataloader):
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.int64).to(device)

        # Clear stale gradients before the new backward pass.  The original
        # called optimizer.step() here, so gradients were never zeroed and a
        # stale-gradient update was applied before the loss was even computed.
        optimizer.zero_grad()

        score = model(X_train)
        loss = criterion(input=score, target=y_train)
        loss.backward()

        optimizer.step()

        if (i + 1) % 10 == 0:
            print("\tloss: %.5f" % loss.item())

    model.eval()
    y_pred_list = list()
    y_val_list = list()
    with torch.no_grad():  # evaluation needs no autograd graph
        for X_val, y_val in val_dataloader:
            X_val = X_val.type(torch.float32).to(device)

            score = model(X_val)
            # squeeze(1) rather than squeeze(): a trailing batch of size 1
            # would otherwise collapse to a 0-d array and break extend().
            y_pred_list.extend(torch.topk(score, k=1, dim=1)[1].squeeze(1).cpu().numpy())
            y_val_list.extend(y_val)

    print("\tvalidation accuracy: %.5f" % accuracy_score(y_true=y_val_list, y_pred=y_pred_list))

The training and validation log is the following, which is expected since I just set if model.training == False in the hook function.

epoch 1 / 2
	loss: 0.69933
	loss: 0.69549
validation...
validation...
validation...
validation...
validation...
validation...
validation...
	validation accuracy: 0.46000
epoch 2 / 2
	loss: 0.70441
	loss: 0.69369
validation...
validation...
validation...
validation...
validation...
validation...
validation...
	validation accuracy: 0.54000

For the global flag to work, you need to define it as global as you modify it:

global flag
if flag:
  flag = False
  # Do something

Thank you for this! I have been confused about this for hours. The updated code provides the following output, which is expected:

epoch 1 / 2
	loss: 0.70723
	loss: 0.68785
validation...
	validation accuracy: 0.46000
epoch 2 / 2
	loss: 0.69220
	loss: 0.71312
validation...
	validation accuracy: 0.54000

For anyone who is interested in the runnable code, here it is

import numpy as np
import torch

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.models import AlexNet
from torchvision.datasets import FakeData

from sklearn.metrics import accuracy_score

NUM_CLASSES = 2

LR = 1e-4
BATCH_SIZE = 32
MAX_EPOCH = 2
# Module-level sentinel: True means the hook has not yet fired for the
# current validation phase.  (A `global` statement at module scope is a
# no-op, so it is omitted here; the hook itself declares `global flag`.)
flag = True

# is_available is a function: without the () the bare function object is
# always truthy, so the CPU fallback branch could never be taken.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transform = T.ToTensor()
train_dataset = FakeData(size=800, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
val_dataset = FakeData(size=200, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = AlexNet(num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

def hook(self, input, output):
    """Forward hook on model.features; fires at most once per validation phase.

    The caller resets `flag` to True before each validation loop; the first
    eval-mode forward pass clears it so subsequent batches are skipped.
    """
    global flag
    if flag and not model.training:
        flag = False
        print("validation...")

model.features.register_forward_hook(hook)

for epoch in range(MAX_EPOCH):
    model.train()
    print("epoch %d / %d" % ((epoch + 1), MAX_EPOCH))
    for i, (X_train, y_train) in enumerate(train_dataloader):
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.int64).to(device)

        # Clear stale gradients before the new backward pass.  The original
        # called optimizer.step() here, so gradients were never zeroed and a
        # stale-gradient update was applied before the loss was even computed.
        optimizer.zero_grad()

        score = model(X_train)
        loss = criterion(input=score, target=y_train)
        loss.backward()

        optimizer.step()

        if (i + 1) % 10 == 0:
            print("\tloss: %.5f" % loss.item())

    model.eval()
    y_pred_list = list()
    y_val_list = list()
    flag = True  # re-arm the forward hook for this validation phase
    with torch.no_grad():  # evaluation needs no autograd graph
        for X_val, y_val in val_dataloader:
            X_val = X_val.type(torch.float32).to(device)

            score = model(X_val)
            # squeeze(1) rather than squeeze(): a trailing batch of size 1
            # would otherwise collapse to a 0-d array and break extend().
            y_pred_list.extend(torch.topk(score, k=1, dim=1)[1].squeeze(1).cpu().numpy())
            y_val_list.extend(y_val)

    print("\tvalidation accuracy: %.5f" % accuracy_score(y_true=y_val_list, y_pred=y_pred_list))