I’m trying to use the value of layer.register_forward_hook
to train my network.
During training, if I use the value returned from network to train my network, the training goes well.
(I’ve checked the gradients, loss decreasing and accuracy increasing)
But if I use the value from layer.register_forward_hook
to train my network, every gradients are zero and loss and accuracy is not changed.
I really don’t know what is the correct way to use the values from layer.register_forward_hook
.
I want to know the method to train network with values from register_forward_hook
function.
Please help me guys.
My entire code is below:
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from utils import progress_bar
import sys
# ====================Register hooking functions====================
global glb_feature
def Get_features(self, input, output):
global glb_feature
glb_feature = output.data
return None
global glb_grad
def Get_grad(self, ingrad, outgrad):
global glb_grad
glb_grad = outgrad
return None
#============================================================
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
batch_size = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# Dataset preparing
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model
print('==> Building model..')
net = StudentNet()
net = net.to(device)
final_layer = net.classifier2
#====================define glb_feature and glb_grad====================
glb_feature = torch.tensor(torch.zeros(batch_size, len(classes))
, requires_grad=True, device=torch.device(device))
glb_grad = torch.tensor(torch.zeros(batch_size, len(classes))
, requires_grad=True, device=torch.device(device))
#============================================================
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Training
def train(epoch):
global glb_feature
global glb_grad
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
# ====================REGISTER_FORWARD_HOOOK====================
final_layer.register_forward_hook(Get_features)
final_tensor = torch.tensor(glb_feature, requires_grad=True
, device=torch.device(device))
#============================================================
loss = criterion(final_tensor, targets) # the way I want to implement
#loss = criterion(outputs, targets) # original code
loss.backward()
optimizer.step()
#====================REGISTER_BACKWARD_HOOK====================
final_layer.register_backward_hook(Get_grad)
print(glb_grad)
#============================================================
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
# main
for epoch in range(start_epoch, start_epoch+50):
train(epoch)