Knowledge Distillation training code not working properly

Hello everyone.
I'm doing a small research project on CNN compression methods, and I'm currently working on Knowledge Distillation. I wrote these two functions to train a simple CNN with and without KD (I tried to follow this KNOWLEDGE DISTILLATION TUTORIAL):

def train_one_epoch(model, optimizer):
    # loss_fn, training_loader and training_set are defined earlier in the script
    running_loss = 0.
    last_loss = 0.
    for i, data in enumerate(training_loader):
        print("Batch " + str(i) + " of " + str(int(len(training_set) / 16)))  # 16 = batch size
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # average loss over the last 1000 batches
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

def train_one_epoch_distillation(teacher, student, optimizer):
    running_loss = 0.
    last_loss = 0.
    teacher.eval()
    student.train()
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss()
    T = 2
    soft_target_loss_weight = 0.25
    ce_loss_weight = 0.75
    for i, data in enumerate(training_loader):
        print("Batch " + str(i) + " of " + str(int(len(training_set) / 16)))  # 16 = batch size
        inputs, labels = data
        # Resize each image to (3, 512, 512) for the teacher, treating it as a 3D volume
        inputst = inputs
        resized_tensors = []
        batch_size, depth, height, width = inputst.size()
        for j in range(batch_size):  # renamed from i so it doesn't shadow the batch counter
            tensor = inputst[j]
            resized_tensor = torch.nn.functional.interpolate(tensor.unsqueeze(0).unsqueeze(0), size=(3, 512, 512), mode='trilinear', align_corners=False)
            resized_tensors.append(resized_tensor.squeeze(0).squeeze(0))

        resized_batch = torch.stack(resized_tensors)
        inputst = resized_batch
        inputs, labels = inputs.cuda(), labels.cuda()
        inputst = inputst.cuda()
        optimizer.zero_grad()

        # Teacher forward pass only, no gradients
        with torch.no_grad():
            teacher_logits = teacher(inputst)

        student_logits = student(inputs)
        soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
        soft_targets_loss = kl_loss(student_logits, soft_targets) * (T**2)
        label_loss = ce_loss(student_logits, labels)
        loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # average loss over the last 1000 batches
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

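For comparison, as far as I can tell the tutorial I linked computes the soft-target term from log-softmaxed student logits at the same temperature, with the KL divergence taking log-probabilities as input. This is a sketch from memory (not the exact tutorial code), reusing the names from my function above:

soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
# nn.KLDivLoss expects log-probabilities as input and probabilities as target
soft_targets_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T**2)

My version passes the raw student_logits straight into kl_loss, but since the problem shows up even with soft_target_loss_weight set to 0, I don't think that difference alone explains it.
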
I run these functions with the following code:

nn_light = LightNN()

nn_light.to(torch.device("cuda"))
optimizer = torch.optim.SGD(nn_light.parameters(), lr=0.001, momentum=0.9)

nn_light.train()
for i in range(5):
    train_one_epoch(nn_light, optimizer)

torch.save(nn_light.state_dict(), "modelli/lw_nn")

PATH = r"modelli\model_20230811_114742_4"  # resnet18 teacher checkpoint
teacher = torchvision.models.resnet18()
teacher.load_state_dict(torch.load(PATH))
student = LightNN()
student.train()
teacher.eval()
teacher.to(torch.device("cuda"))
student.to(torch.device("cuda"))
optimizer = torch.optim.SGD(student.parameters(), lr=0.001, momentum=0.9)
for i in range(5):
    train_one_epoch_distillation(teacher, student, optimizer)
torch.save(student.state_dict(), "modelli/lw_distillata")

When I run this code on the Food101 dataset, I get sensible results from the normal training and terrible results from the KD training. The weird thing is that even if I set soft_target_loss_weight to 0 and ce_loss_weight to 1, which should give the same results as the normal training, I still get the same nonsensical results as before (the same thing happens if I set loss = label_loss directly).
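For what it's worth, this is the kind of minimal check I would expect to pass on a single batch if the two paths were really equivalent (hypothetical snippet, reusing the names from the code above):

# With soft_target_loss_weight = 0 and ce_loss_weight = 1, the KD loss on a batch
# should reduce to the same CrossEntropyLoss used by train_one_epoch
inputs, labels = next(iter(training_loader))
inputs, labels = inputs.cuda(), labels.cuda()
with torch.no_grad():
    logits = student(inputs)
    print(loss_fn(logits, labels).item())                 # loss used in train_one_epoch
    print(nn.CrossEntropyLoss()(logits, labels).item())   # label_loss term in the KD function

Both prints should match, assuming loss_fn is a plain CrossEntropyLoss, which is what the comparison above relies on.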
I think I might have made some dumb mistake in the code, but I've been looking for a while now and, being kind of new to PyTorch, I can't seem to spot it. Does anyone have any idea how to fix this?