Thanks for your reply. Below are the specific loss function that is using a lot of memory (kd_s_loss) and the distillation module's loss method that calls it. May I also ask whether it is okay practice to put the predictions and auxiliary tensors into a dictionary and let each loss function pick out what it needs?
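For context, the dictionaries are laid out roughly like this (the keys are the ones used in the code below; the sizes and dimension names are only illustrative):

import torch

B, C, X, Y, Z = 2, 20, 128, 128, 16  # illustrative sizes only
preds = {
    'student_out': {'ssc_logits': torch.randn(B, C, X, Y, Z)},
    'teacher_out': {'ssc_logits': torch.randn(B, C, X, Y, Z)},
}
target = {
    'target': torch.randint(0, C, (B, X, Y, Z)),  # voxel-wise class labels
    'class_weights': torch.ones(C),
}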
import torch
import torch.nn.functional as F

def kd_s_loss(pred, target, device, weight=0.25):
    student_logits = pred['student_out']['ssc_logits'].to(device)
    # Remove grad from teacher logits
    teacher_logits = pred['teacher_out']['ssc_logits'].detach().to(device)
    # Flatten spatial dims: (B, C, N) for the logits, (B, N) for the target
    student_logits = student_logits.reshape(student_logits.shape[0], student_logits.shape[1], -1)
    teacher_logits = teacher_logits.reshape(teacher_logits.shape[0], teacher_logits.shape[1], -1)
    target = target['target'].reshape(student_logits.shape[0], student_logits.shape[-1]).to(device)
    T = 1  # distillation temperature
    loss_center, loss_sim = 0., 0.
    for y_s, y_t, label in zip(student_logits, teacher_logits, target):
        # (C, N) -> (N, C): one feature vector per voxel
        y_s, y_t = y_s.flatten(1).permute(1, 0), y_t.flatten(1).permute(1, 0)
        # Category centers: mean feature per class present in this sample, ignoring class 0
        unique_label = label.unique()
        unique_label = unique_label[unique_label != 0]
        mask = label[:, None] == unique_label[None, :]
        y_t_center = (y_t[:, None, :] * mask[:, :, None]).sum(0) / mask.sum(0)[:, None]
        y_s_center = (y_s[:, None, :] * mask[:, :, None]).sum(0) / mask.sum(0)[:, None]
        # KL loss between student and teacher category centers
        p_s = F.log_softmax(y_s_center / T, dim=1)
        p_t = F.softmax(y_t_center / T, dim=1)
        loss_center = loss_center + (F.kl_div(p_s, p_t, reduction='none') * (T ** 2)).sum(-1).mean()
        # MSE loss between the voxel-to-center cosine similarity maps
        sim_t = torch.cosine_similarity(y_t[:, None], y_t_center[None, :], dim=-1)
        sim_s = torch.cosine_similarity(y_s[:, None], y_s_center[None, :], dim=-1)
        loss_sim = loss_sim + F.mse_loss(sim_s, sim_t)
    # Average over the batch
    loss = (loss_center + loss_sim) / target.shape[0]
    return loss * weight
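On the memory side, I suspect the peak comes from the (N, K, C) intermediates: `y_t[:, None, :] * mask[:, :, None]` materialises a voxels x classes x channels tensor, and the broadcast inside `torch.cosine_similarity` does the same (N = voxels per sample, K = classes present, C = channels). Below is a lower-memory variant I have been considering, which accumulates the class centers with index_add_ and computes the similarities as a matmul of normalised features. It is only a sketch and I have not verified it matches the original numerically:

import torch
import torch.nn.functional as F

def class_centers_and_sims(y, label, unique_label):
    # y: (N, C) per-voxel features, label: (N,) integer labels,
    # unique_label: (K,) sorted labels with 0 already removed (as in kd_s_loss).
    keep = label != 0
    idx = torch.searchsorted(unique_label, label[keep])  # class index in [0, K)
    K = unique_label.numel()
    sums = torch.zeros(K, y.shape[1], device=y.device, dtype=y.dtype)
    sums.index_add_(0, idx, y[keep])                      # per-class feature sums
    counts = torch.bincount(idx, minlength=K).to(y.dtype).clamp_min(1)
    centers = sums / counts[:, None]                      # (K, C) class means
    # Cosine similarity of every voxel to every class center as an (N, K) matmul
    sims = F.normalize(y, dim=1) @ F.normalize(centers, dim=1).t()
    return centers, sims

With this, the largest intermediates are (N, C) and (N, K) rather than (N, K, C), which should cut the peak memory of this loss considerably when C is large.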
The distillation module has a loss method that calls kd_s_loss:
def loss(self, preds, target):
    target['class_weights'] = self.class_weights.type(torch.float32)
    loss_dict = {}
    if self.bundle_loss_to_device:
        # Move the target to the criterion device
        target = {k: v.to(self.criterion_device) for k, v in target.items()}
        # Move the preds (a nested dict) to the criterion device
        preds = {
            k: {kk: vv.to(self.criterion_device) for kk, vv in v.items()}
            for k, v in preds.items()
        }
    # Collect the teacher and student losses **not important**
    for model_name, model in zip(["student", "teacher"], [self.student_model, self.teacher_model]):
        if model is None:
            continue
        # Get the correct device for the model
        model_device = self.criterion_device if self.bundle_loss_to_device else next(model.parameters()).device
        # Move the target tensors to the model's device and detach them
        model_target = {k: v.to(model_device).detach() for k, v in target.items()}
        # Calculate the model's own loss
        model_loss_dict = model.loss(preds.get(f"{model_name}_out", None), model_target)
        loss_dict.update({
            f"{model_name}_{loss_key}": loss_value
            for loss_key, loss_value in model_loss_dict.items()
        })
    distillation_loss_map = {
        'pvi': pvi_loss,
        'kd_s': kd_s_loss,
        'kd_t': kd_t_loss,
        'kd_t_mask': kd_t_mask_loss
    }
    # Calculate the distillation losses **<- distillation losses here**
    if self.criterions is not None and len(self.criterions) > 0:
        assert self.student_model is not None and self.teacher_model is not None, \
            "Both models must be provided to use distillation criterions!"
        loss_dict.update({
            f"distill_{criterion}": distillation_loss_map[criterion](preds, target, self.criterion_device)
            for criterion in self.criterions
        })
    # Move the loss values to the correct device
    loss_dict = {k: v.to(self.criterion_device) for k, v in loss_dict.items()}
    return loss_dict
After this, the loss_dict values are summed and backward() is called on the total.
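For completeness, this is roughly how the dict is consumed afterwards (`distiller` is just a placeholder name for the module above):

loss_dict = distiller.loss(preds, target)
total_loss = sum(loss_dict.values())  # sum student/teacher/distillation terms
total_loss.backward()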