RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn while using a pretrained BERT model

Hi, I was trying to integrate a pretrained BERT model into my own custom model. Here is my custom class:

class BERTClassifier(torch.nn.Module):
    def __init__(self, bert_model, n_classes):
        super(BERTClassifier, self).__init__()
        self.bert_model = bert_model
        self.max_len = 512
        if isinstance(self.bert_model, BertModel):
            self.out_features = self.bert_model.encoder.layer[2].output.dense.out_features
        else:
            self.out_features = self.bert_model.bert.encoder.layer[2].output.dense.out_features
            
        self.dropout = torch.nn.Dropout(p=0.2)
        self.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.output = torch.nn.Softmax(dim=-1)

#         for param in self.bert_model.encoder.parameters():
#             param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_output.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        probs = self.output(logits)
        return probs

Here’s how I initialize the class instance:

# Using BertModel from transformers library
bert_model = BertModel.from_pretrained(
        tokenizer_path,
        config=AutoConfig.from_pretrained(tokenizer_path, output_attentions=True, output_hidden_states=True))

EPOCHS = 10

my_bert = BERTClassifier(bert_model, 2).to(DEVICE)

Here is my training function:

def train_model(net, train_iter, val_iter, epochs, lr=0.02, loss=None, device=None, save_model=None,
                scheduler=None, scheduler_conf=None, use_tensorboard=False):

    if device is None:
        device = try_gpu()
    print(f"Training on {device}")

    net.to(device)

    if loss is None:
        loss = torch.nn.BCELoss()#.to(device)

    if use_tensorboard:
        writer = SummaryWriter()

    # optimizer = torch.optim.Adadelta(net.parameters(), lr=lr)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)

    if scheduler is not None:
        if scheduler_conf is not None:
            scheduler = scheduler(optimizer, **scheduler_conf)

        else:
            scheduler = scheduler(optimizer)

    for epoch in range(epochs):
        # Switch the model to training mode
        net.train()
        train_loss = 0
        cnt_train = 0

        for batch in tqdm(train_iter):
            _input_ids = batch["input_ids"]
            _attention_mask = batch["attention_mask"]
            y = batch["targets"]

            optimizer.zero_grad()

            y_pred = torch.argmax(net(input_ids=_input_ids, attention_mask=_attention_mask), dim=-1).float()
            l = loss(y_pred, y)
            print(y_pred.requires_grad, y.requires_grad)
            for param in net.parameters():
                print(param.requires_grad)

            train_loss += l
            cnt_train += 1
            l.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
            optimizer.step()

        with torch.no_grad():
            val_loss = 0
            val_roc_auc = 0
            cnt_val = 0

            for val_batch in val_iter:
                val_input_ids = val_batch["input_ids"]
                val_attention_mask = val_batch["attention_mask"]
                y_val = val_batch["targets"]

                # Cast targets to int for torchmetrics
                _tmp = torch.IntTensor(y_val.cpu().numpy())

                y_val_pred = torch.argmax(net(val_input_ids, val_attention_mask), dim=-1).float()

                val_loss += loss(y_val_pred, y_val)
                val_roc_auc += torchmetrics.functional.auroc(y_val_pred, _tmp).item()
                cnt_val += 1

            if scheduler is not None:
                scheduler.step()

            if use_tensorboard:
                writer.add_scalar("train_loss", l, epoch)
                writer.add_scalar("val_loss", val_loss / cnt_val, epoch)
                writer.add_scalar("val_roc_auc", val_roc_auc / cnt_val, epoch)

        print(f"epoch: {epoch}", f"train loss: {(train_loss / cnt_train):.5f}",
              f"val loss: {val_loss / cnt_val:.5f} val ROC AUC: {val_roc_auc / cnt_val:.5f}",
              f"lr: {optimizer.param_groups[0]['lr']}"
             )

    if save_model is not None:
        torch.save(net.state_dict(), save_model)
        print(f"Model saved to {save_model}")

    return

I also checked all the parameters in my class with:

for param in my_bert.parameters():
   print(param.requires_grad)

All of them are True. I don’t understand what the problem is.

torch.argmax is not differentiable and will thus break the computation graph.
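
For example, a minimal sketch just to illustrate the point:

import torch

x = torch.randn(4, 2, requires_grad=True)
probs = torch.softmax(x, dim=-1)
print(probs.requires_grad)                 # True -- still attached to the graph
preds = torch.argmax(probs, dim=-1).float()
print(preds.requires_grad, preds.grad_fn)  # False None -- the graph is cut here
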
Also, nn.BCELoss expects probabilities as the model output and is used for binary or multi-label classification use cases, so you should use a sigmoid in this case.
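
One way to put that together, as a minimal self-contained sketch (a plain Linear layer stands in for the classifier head on top of pooler_output; the hidden size of 768 and batch size of 8 are just placeholder values):

import torch

classifier = torch.nn.Linear(768, 1)          # single output unit for binary classification
loss_fn = torch.nn.BCELoss()

pooled_output = torch.randn(8, 768)           # stands in for bert_output.pooler_output
y = torch.randint(0, 2, (8,)).float()         # binary targets as floats

probs = torch.sigmoid(classifier(pooled_output)).squeeze(-1)  # probabilities in [0, 1]
l = loss_fn(probs, y)                         # loss computed on probabilities, no argmax
l.backward()                                  # works: probs has a grad_fn

(Alternatively, return the raw logits and use nn.BCEWithLogitsLoss, which applies the sigmoid internally and is more numerically stable.)
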

Thank you @ptrblck, that solved the issue!