Here is my full code.
# Register hooking function
global glb_feature_teacher
global glb_feature_student
def Get_features4teacher(self, input, output):
global glb_feature_teacher
glb_feature_teacher = output.data
return None
# end
def Get_features4student(self, input, output):
global glb_feature_student
glb_feature_student = output.data
return None
# end
# Parsers
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_batch = 64
num_emb = 128
trainset = torchvision.datasets.CIFAR10(root='../../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=train_batch, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model
print('==> Building model..')
teacher_net = ResNet50()
student_net = StudentNet()
teacher_net = teacher_net.to(device)
student_net = student_net.to(device)
if device == 'cuda':
teacher_net = torch.nn.DataParallel(teacher_net)
student_net = torch.nn.DataParallel(student_net)
cudnn.benchmark = True
print('Loading teacher, student network weight file')
try:
checkpoint_teacher = torch.load('./resnet50.t7')
teacher_net.load_state_dict(checkpoint_teacher['net'])
except FileNotFoundError:
print('ERROR::No pretrained teacher network file found!')
sys.exit(1)
t_emb_layer = teacher_net.module.linear1
s_emb_layer = student_net.module.classifier1
'''=============================parameter settings=========================='''
for param in student_net.parameters():
param.requires_grad=True
for param in teacher_net.parameters():
param.requires_grad=False
'''==============================LOSS FUNCTION LOCATION=============================='''
mse_loss = nn.MSELoss()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student_net.parameters(), lr=args.lr, momentum=0.9 ,weight_decay=5e-4)
glb_feature_teacher = torch.tensor(torch.zeros(train_batch, num_emb), requires_grad=False, device=torch.device(device))
glb_feature_student = torch.tensor(torch.zeros(train_batch, num_emb), requires_grad=True, device=torch.device(device))
def train(epoch):
global glb_feature_teacher
global glb_feature_student
print('\nEpoch: %d' % epoch)
student_net.train()
teacher_net.eval()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs_teacher = teacher_net(inputs)
outputs_student = student_net(inputs)
'''============================================================================'''
t_emb_layer.register_forward_hook(Get_features4teacher)
s_emb_layer.register_forward_hook(Get_features4student)
emb_teacher = torch.tensor(glb_feature_teacher, requires_grad=False, device=torch.device(device))
emb_student = torch.tensor(glb_feature_student, requires_grad=True, device=torch.device(device))
loss_c = criterion(outputs_student, targets)
loss_v = mse_loss(emb_student, emb_teacher)
loss = loss_c + 0.1*loss_v
loss.backward()
optimizer.step()
torch.cuda.synchronize()
'''==========================GRADIENT CHECKING================================='''
grad_of_params_student = {}
for name, parameter in student_net.named_parameters():
grad_of_params_student[name] = parameter.grad
#print(name, parameter.grad)
#print('checking student: ', parameter.size())
grad_of_params_teacher = {}
for name, parameter in teacher_net.named_parameters():
grad_of_params_teacher[name] = parameter.grad
#print('checking teacher: ', parameter.size())
print('student: ', grad_of_params_student['module.classifier1.weight']) # for student net
print('teacher: ', grad_of_params_teacher['module.linear1.weight']) # for teacher net
'''============================================================================'''
train_loss += loss.item()
_, predicted = outputs_student.max(1) #max(1): second value returns argmax
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
and the main function is below:
for epoch in range(start_epoch, start_epoch+100):
train(epoch)
The code does not calculate gradient when only loss_v is applied.
How can i fix this bug?