Hello everyone.
I’m doing a small research project on CNN compression methods, and I’m currently working on Knowledge Distillation. I wrote these two functions to train a simple CNN with and without KD (I tried to follow this KNOWLEDGE DISTILLATION TUTORIAL):
def train_one_epoch(model, optimizer):
    running_loss = 0.
    last_loss = 0.
    for i, data in enumerate(training_loader):
        print("Batch number " + str(i) + " of " + str(int(len(training_set) / 16)))
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.
    return last_loss
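Both functions rely on a few globals defined earlier in my script (training_set, training_loader, loss_fn). Roughly, they are set up like this (the transform here is only illustrative; the batch size is 16, which is where the len(training_set)/16 in the progress print comes from):

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Food101 training split with a placeholder transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
training_set = torchvision.datasets.Food101(root="data", split="train",
                                            download=True, transform=transform)
training_loader = torch.utils.data.DataLoader(training_set, batch_size=16, shuffle=True)
loss_fn = nn.CrossEntropyLoss()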
def train_one_epoch_distillation(teacher, student, optimizer):
    running_loss = 0.
    last_loss = 0.
    teacher.eval()
    student.train()
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss()
    T = 2
    soft_target_loss_weight = 0.25
    ce_loss_weight = 0.75
    for i, data in enumerate(training_loader):
        print("Batch number " + str(i) + " of " + str(int(len(training_set) / 16)))
        inputs, labels = data
        inputst = inputs
        # Resize each image so the teacher gets 512x512 inputs
        resized_tensors = []
        batch_size, depth, height, width = inputst.size()
        for j in range(batch_size):  # renamed from i so it doesn't shadow the batch index
            tensor = inputst[j]
            resized_tensor = torch.nn.functional.interpolate(
                tensor.unsqueeze(0).unsqueeze(0), size=(3, 512, 512),
                mode='trilinear', align_corners=False)
            resized_tensors.append(resized_tensor.squeeze(0).squeeze(0))
        resized_batch = torch.stack(resized_tensors)
        inputst = resized_batch
        inputs, labels = inputs.cuda(), labels.cuda()
        inputst = inputst.cuda()
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_logits = teacher(inputst)
        student_logits = student(inputs)
        soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
        soft_targets_loss = kl_loss(student_logits, soft_targets) * (T**2)
        label_loss = ce_loss(student_logits, labels)
        loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.
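For reference, this is how I understand the soft-target term from the tutorial: the teacher probabilities are softened with temperature T and compared against the log-probabilities of the student logits at the same T (as far as I know, nn.KLDivLoss expects log-probabilities as its first argument). A rough sketch, where distillation_loss is just an illustrative helper and not code from my script:

import torch
import torch.nn as nn

def distillation_loss(student_logits, teacher_logits, labels, T=2,
                      soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    # Teacher probabilities softened with temperature T; student logits
    # turned into log-probabilities at the same T for KLDivLoss.
    soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
    soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
    soft_targets_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T ** 2)
    label_loss = nn.CrossEntropyLoss()(student_logits, labels)
    return soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss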
I run these functions with the following code:
nn_light = LightNN()
nn_light.to(torch.device("cuda"))
optimizer = torch.optim.SGD(nn_light.parameters(),lr=0.001, momentum=0.9)
nn_light.train()
for i in range(5):
    train_one_epoch(nn_light, optimizer)
torch.save(nn_light.state_dict(), "modelli/lw_nn")
PATH = "modelli\model_20230811_114742_4"
teacher = torchvision.models.resnet18()
teacher.load_state_dict(torch.load(PATH))
student = LightNN()
student.train()
teacher.eval()
teacher.to(torch.device("cuda"))
student.to(torch.device("cuda"))
optimizer = torch.optim.SGD(student.parameters(),lr=0.001,momentum=0.9)
for i in range(5):
    train_one_epoch_distillation(teacher, student, optimizer)
torch.save(student.state_dict(), "modelli/lw_distillata")
When I run this code on the Food101 dataset, I get sensible results from the normal training and terrible results from the KD training. The weird thing is that even if I set soft_target_loss_weight to 0 and ce_loss_weight to 1, which should give the same results as the normal training, I still get the same nonsensical results as before (the same thing happens if I set loss = label_loss).
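To quantify "sensible" vs "terrible": I compare top-1 accuracy of the two students with a simple loop along these lines (test_loader is a DataLoader over the Food101 test split, built like training_loader and not shown here):

def evaluate(model, loader):
    # Top-1 accuracy over a DataLoader
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

print("plain student:", evaluate(nn_light, test_loader))
print("distilled student:", evaluate(student, test_loader))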
I think I might have made some dumb mistake in the code, but I’ve been looking for a while now and, being fairly new to PyTorch, I can’t seem to spot it. Does anyone have an idea of how to fix this?