Moving memory-expensive losses to a different device

Hi, I am currently working on a knowledge distillation problem with two large models. One model is frozen, and the other is trainable. I am a student and have been allocated 2x 12 GB CUDA GPUs.

Training works with both models on a single GPU (batch size 1) and an MSE loss. But when I try a more expensive custom loss based on KL divergence (similar to this link), I instantly run out of memory: the loss alone tries to allocate more than 2 GB.

I have tried moving the models to separate GPUs, but I still run out of memory. I have also tried keeping both models on one device and moving the preds and targets to the other GPU before computing the losses, but I run out of memory there too. In both cases (including with the losses moved to the CPU), I have debugged and can see that the computed grads end up on the correct devices.

Overall, is there any particular reason why this strategy wouldn’t work? Perhaps there is some inefficiency during the backward pass?
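
For reference, here is a minimal sketch of the strategy I mean (toy tensors rather than my actual models): the forward pass lives on cuda:0, the loss is computed on cuda:1, and autograd routes the gradients back to cuda:0.

import torch
import torch.nn.functional as F

# Toy stand-ins for the model output and target, both living on cuda:0
pred = torch.randn(1, 8, device="cuda:0", requires_grad=True)
target = torch.randn(1, 8, device="cuda:0")

# Compute the loss on the second GPU; .to() is differentiable,
# so autograd copies the gradient back across devices during backward()
loss = F.mse_loss(pred.to("cuda:1"), target.to("cuda:1"))
loss.backward()

print(loss.device)       # cuda:1
print(pred.grad.device)  # cuda:0 -- the grad lands on the original device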

Could you post the custom loss function? From what I’ve seen, it’s easy to accidentally keep data in memory that should be detached (see the FAQ Memory Topic).
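
A typical example of that pitfall (a generic sketch, not taken from your code) is accumulating a running loss without detaching it, which keeps every iteration’s autograd history alive:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 8)                      # toy model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

total_loss = 0.0
for _ in range(100):
    x, y = torch.randn(4, 8), torch.randn(4, 8)    # toy batch
    loss = F.mse_loss(model(x), y)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # BAD:  total_loss += loss   # keeps autograd history alive across iterations
    # GOOD: convert to a plain Python number before accumulating
    total_loss += loss.item()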

Thanks for your reply. Here is the specific loss function that is using a lot of memory, along with the distillation loss method that calls it. May I also ask: is it okay practice to put the predictions and auxiliary data into a dictionary and let the loss function pick out what it needs?

import torch
import torch.nn.functional as F


def kd_s_loss(pred, target, device, weight=0.25):
    student_logits = pred['student_out']['ssc_logits'].to(device)
    # Remove grad from teacher logits
    teacher_logits = pred['teacher_out']['ssc_logits'].detach().to(device)
    # Flatten spatial dims: (B, C, ...) -> (B, C, N)
    student_logits = student_logits.reshape(student_logits.shape[0], student_logits.shape[1], -1)
    teacher_logits = teacher_logits.reshape(teacher_logits.shape[0], teacher_logits.shape[1], -1)
    # Flatten the target to (B, N) to match the logits
    target = target['target'].reshape(student_logits.shape[0], student_logits.shape[-1]).to(device)
    T = 1  # distillation temperature
    loss_center, loss_sim = 0., 0.
    for y_s, y_t, label in zip(student_logits, teacher_logits, target):
        # (C, N) -> (N, C): one row of logits per spatial location
        y_s, y_t = y_s.flatten(1).permute(1, 0), y_t.flatten(1).permute(1, 0)

        # Category centers: mean feature per non-background class
        unique_label = label.unique()
        unique_label = unique_label[unique_label != 0]
        mask = label[:, None] == unique_label[None, :]
        y_t_center = (y_t[:, None, :] * mask[:, :, None]).sum(0) / mask.sum(0)[:, None]
        y_s_center = (y_s[:, None, :] * mask[:, :, None]).sum(0) / mask.sum(0)[:, None]

        # KL loss for category centers
        p_s = F.log_softmax(y_s_center / T, dim=1)
        p_t = F.softmax(y_t_center / T, dim=1)
        loss_center = loss_center + (F.kl_div(p_s, p_t, reduction='none') * (T ** 2)).sum(-1).mean()

        # MSE loss for relation with category centers
        sim_t = torch.cosine_similarity(y_t[:, None], y_t_center[None, :], dim=-1)
        sim_s = torch.cosine_similarity(y_s[:, None], y_s_center[None, :], dim=-1)
        loss_sim = loss_sim + F.mse_loss(sim_s, sim_t)
    loss = (loss_center + loss_sim) / target.shape[0]
    return loss * weight
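
As an aside, I suspect the big allocation comes from the broadcasted intermediates inside the loop: the center computation (y_s[:, None, :] * mask[:, :, None]) and the two cosine_similarity calls each materialize a tensor of shape (N, K, C), where N is the flattened spatial size and K the number of classes, so for large volumes this alone could exceed 2 GB. A rough sketch of how the similarity term could avoid that intermediate (just an idea, not what I currently run) is to normalize and use a matmul:

import torch.nn.functional as F

def cosine_sim_to_centers(y, centers):
    # y: (N, C) per-location features, centers: (K, C) class centers.
    # Same result (up to eps handling) as
    #   torch.cosine_similarity(y[:, None], centers[None, :], dim=-1)
    # but computed as a matmul of normalized vectors, so no (N, K, C)
    # intermediate tensor is materialized.
    y_norm = F.normalize(y, dim=-1)
    c_norm = F.normalize(centers, dim=-1)
    return y_norm @ c_norm.t()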

The distillation module has a loss method that calls kd_s_loss:

def loss(self, preds, target):
    target['class_weights'] = self.class_weights.type(torch.float32)
    loss_dict = {}

    if self.bundle_loss_to_device:
        # Move the target to the loss device
        target = {k: v.to(self.criterion_device) for k, v in target.items()}
        # Move the preds to the loss device (preds is a nested dict)
        preds = {
            k: {kk: vv.to(self.criterion_device) for kk, vv in v.items()}
            for k, v in preds.items()
        }

    # Collect the teacher and student losses **not important**
    for model_name, model in zip(["student", "teacher"], [self.student_model, self.teacher_model]):
        if model is None:
            continue

        # Get the correct device for the model
        model_device = self.criterion_device if self.bundle_loss_to_device else next(model.parameters()).device

        # Copy the target dict
        model_target = {k: v.to(model_device).detach() for k, v in target.items()}
        # Calculate the loss
        model_loss_dict = model.loss(preds.get(f"{model_name}_out", None), model_target)
        loss_dict.update({
            f"{model_name}_{loss_key}": loss_value
            for loss_key, loss_value in model_loss_dict.items()
        })

    distillation_loss_map = {
        'pvi': pvi_loss,
        'kd_s': kd_s_loss,
        'kd_t': kd_t_loss,
        'kd_t_mask': kd_t_mask_loss
    }

    # Calculate the distillation loss **<- distillation losses here**
    if self.criterions is not None and len(self.criterions) > 0:
        assert self.student_model is not None and self.teacher_model is not None, \
            "Both models must be provided to use distillation criterions!"
        loss_dict.update({
            f"distill_{criterion}": distillation_loss_map[criterion](preds, target, self.criterion_device)
            for criterion in self.criterions
        })

    # Move the loss values to the correct device
    loss_dict = {k: v.to(self.criterion_device) for k, v in loss_dict.items()}

    return loss_dict

After this, the loss_dict values are summed and backward() is called on the total loss.
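
Roughly, that final step looks like this (a simplified sketch; optimizer and the surrounding objects are the usual training-loop pieces, not the exact names in my code):

loss_dict = self.loss(preds, target)

# Sum the individual loss terms (all already moved to self.criterion_device)
total_loss = sum(loss_dict.values())

# Backward through the summed loss; autograd routes the grads
# back to each model's own device
total_loss.backward()
optimizer.step()
optimizer.zero_grad()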