Loss.backward() causing unexpected errors

I am trying to train a BERT model with 3 output heads on 3 tasks. When I begin training I run into the error "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". I have set most of the parameters to requires_grad=False because for the first epoch I only want to update the output layers. This approach worked well with 2 tasks/heads; the only change is adding the third head/task, and now the error appears. Nothing else has changed.
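The freezing step looks roughly like the sketch below (approximate; the head names match the model init further down, the exact loop in my code may differ slightly):

for name, param in model.named_parameters():
    # for the first epoch, keep only the three task-specific output heads trainable
    if any(head in name for head in ("task1_classifier", "task2_classifier", "task3_classifier")):
        param.requires_grad = True
    else:
        param.requires_grad = False
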
Here is where the error occurs:

def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> List[torch.Tensor]:
    model.train()
    if self.l0_module is not None:
        self.l0_module.train()
    inputs = self._prepare_inputs(inputs)

    distill_loss = None
    distill_ce_loss = None
    # if self.teacher_model is not None:
    with torch.no_grad():
        # only retain inputs of certain keys for the teacher
        teacher_inputs_keys = ["input_ids", "attention_mask", "token_type_ids", "position_ids", "labels",
                               "output_attentions", "output_hidden_states", "return_dict"]
        teacher_inputs = {key: inputs[key]
                          for key in teacher_inputs_keys if key in inputs}
        self.shortens_inputs(teacher_inputs)
        teacher_outputs = self.teacher_model(**teacher_inputs)
    self.shortens_inputs(inputs)
    student_outputs = model(**inputs)  #! get the two outputs

    zs = {key: inputs[key] for key in inputs if "_z" in key}  #! extract the zs
    distill_loss, distill_ce_loss, loss = self.calculate_distillation_loss(
        teacher_outputs, student_outputs, zs)

    lagrangian_loss = None

    loss.backward()
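
Right before the backward call, a quick check (debugging sketch only) to confirm whether the loss is still attached to the autograd graph:

print(loss.requires_grad, loss.grad_fn)  # False / None would mean no graph was built for this loss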

Here is the model init:

def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.bert = CoFiBertModel(config)

    self.do_layer_distill = getattr(config, "do_layer_distill", False)
    self.num_labels = 2
    self.task = None

    self.dropout = nn.Dropout(p=0.1, inplace=False)

    self.task1_classifier = nn.Linear(config.hidden_size, 2)
    self.task2_classifier = nn.Linear(config.hidden_size, 2)
    self.task3_classifier = nn.Linear(config.hidden_size, 2)

    self.post_init()
    if self.do_layer_distill:
        self.layer_transformation = nn.Linear(
            config.hidden_size, config.hidden_size)
    else:
        self.layer_transformation = None
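
After the freezing step, a quick sanity check (sketch) that the new task3_classifier was not frozen along with the backbone:

for name, param in model.named_parameters():
    if "task3_classifier" in name:
        print(name, param.requires_grad)  # should print True for the weight and bias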