RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn while using a pretrained BERT model

Hi, I was trying to integrate a pretrained BERT model into my own custom model. Here is my custom class:

class BERTClassifier(torch.nn.Module):
    def __init__(self, bert_model, n_classes):
        super(BERTClassifier, self).__init__()
        self.bert_model = bert_model
        self.max_len = 512
        if isinstance(self.bert_model, BertModel):
            self.out_features = self.bert_model.encoder.layer[2].output.dense.out_features
        else:
            self.out_features = self.bert_model.bert.encoder.layer[2].output.dense.out_features
            
        self.dropout = torch.nn.Dropout(p=0.2)
        self.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.output = torch.nn.Softmax(dim=-1)

#         for param in self.bert_model.encoder.parameters():
#             param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = bert_output.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        probs = self.output(logits)
        return probs

Here’s how I initialize the class instance:

# Using BertModel from transformers library
bert_model = BertModel.from_pretrained(
        tokenizer_path,
        config=AutoConfig.from_pretrained(tokenizer_path, output_attentions=True, output_hidden_states=True))

EPOCHS = 10

my_bert = BERTClassifier(bert_model, 2).to(DEVICE)

Here is my training function:

def train_model(net, train_iter, val_iter, epochs, lr=0.02, loss=None, device=None, save_model=None,
                scheduler=None, scheduler_conf=None, use_tensorboard=False):

    if device is None:
        device = try_gpu()
    print(f"Training on {device}")

    net.to(device)

    if loss is None:
        loss = torch.nn.BCELoss()#.to(device)

    if use_tensorboard:
        writer = SummaryWriter()

    # optimizer = torch.optim.Adadelta(net.parameters(), lr=lr)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)

    if scheduler is not None:
        if scheduler_conf is not None:
            scheduler = scheduler(optimizer, **scheduler_conf)

        else:
            scheduler = scheduler(optimizer)

    for epoch in range(epochs):
        # Set the model to training mode
        net.train()
        train_loss = 0
        cnt_train = 0

        for batch in tqdm(train_iter):
            _input_ids = batch["input_ids"]
            _attention_mask = batch["attention_mask"]
            y = batch["targets"]

            optimizer.zero_grad()

            y_pred = torch.argmax(net(input_ids=_input_ids, attention_mask=_attention_mask), dim=-1).float()
            l = loss(y_pred, y)
            print(y_pred.requires_grad, y.requires_grad)
            for param in net.parameters():
                print(param.requires_grad)

            train_loss += l
            cnt_train += 1
            l.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
            optimizer.step()

        with torch.no_grad():
            val_loss = 0
            val_roc_auc = 0
            cnt_val = 0

            for val_batch in val_iter:
                val_input_ids = val_batch["input_ids"]
                val_attention_mask = val_batch["attention_mask"]
                y_val = val_batch["targets"]

                # For torchmetrics, cast the targets to int
                _tmp = torch.IntTensor(y_val.cpu().numpy())

                y_val_pred = torch.argmax(net(val_input_ids, val_attention_mask), dim=-1).float()

                val_loss += loss(y_val_pred, y_val)
                val_roc_auc += torchmetrics.functional.auroc(y_val_pred, _tmp).item()
                cnt_val += 1

            if scheduler is not None:
                scheduler.step()

            if use_tensorboard:
                writer.add_scalar("train_loss", l, epoch)
                writer.add_scalar("val_loss", val_loss / cnt_val, epoch)
                writer.add_scalar("val_roc_auc", val_roc_auc / cnt_val, epoch)

        print(f"epoch: {epoch}", f"train loss: {(train_loss / cnt_train):.5f}",
              f"val loss: {val_loss / cnt_val:.5f} val ROC AUC: {val_roc_auc / cnt_val:.5f}",
              f"lr: {optimizer.param_groups[0]['lr']}"
             )

    if save_model is not None:
        torch.save(net.state_dict(), save_model)
        print(f"Model saved to {save_model}")

    return

I also checked all the parameters in my class with:

for param in my_bert.parameters():
   print(param.requires_grad)

All of them are True, so I don’t understand what the problem is.

torch.argmax is not differentiable and will thus break the computation graph.
Also, nn.BCELoss expects probabilities as the model output and is used for binary or multi-label classification use cases, so you should use a sigmoid in this case.
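
For reference, a minimal sketch of how the head and the training step could look after this advice. The names BERTBinaryClassifier, training_step, loss_fn and hidden_size are illustrative and not from the original post; it assumes a binary task with a single-logit head and uses torch.nn.BCEWithLogitsLoss (logits plus an internal sigmoid) as an equivalent of "sigmoid + nn.BCELoss". The two key changes are that the loss sees the raw model output instead of the argmax result, and that hard predictions are only computed for metrics:

import torch

class BERTBinaryClassifier(torch.nn.Module):
    def __init__(self, bert_model, hidden_size):
        super().__init__()
        self.bert_model = bert_model
        self.dropout = torch.nn.Dropout(p=0.2)
        # one logit for a binary task
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(bert_output.pooler_output)
        # return raw logits: no softmax and no argmax inside the forward pass
        return self.classifier(pooled_output).squeeze(-1)

def training_step(net, loss_fn, input_ids, attention_mask, targets):
    # loss_fn would be torch.nn.BCEWithLogitsLoss(), which applies the sigmoid internally
    logits = net(input_ids=input_ids, attention_mask=attention_mask)
    l = loss_fn(logits, targets.float())  # differentiable: the loss sees logits, not an argmax output
    l.backward()
    # hard predictions are only needed for metrics and stay outside the backward path
    preds = (torch.sigmoid(logits) > 0.5).long()
    return l, preds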

Thank you @ptrblck that solved the issue!

Thank you @ptrblck for this hint. I have a similar problem, I think. Could you give me a hint on how to resolve the issue?

I have a binary classification with example output

tensor([0.0213, 0.1017], requires_grad=True)

From this output, I infer the predicted class using

result = model(input_tensors).argmax(dim=-1)

There is no problem in training and validation, but now that I’m trying to implement Deeplift with Captum, I get a similar error as mentioned in the thread.
Can you tell me how to work around the issue? I still have to use some equivalent of torch.argmax to be able to get a single class prediction, don’t I?

If desired, I can open up a new topic with some more explanation.

torch.argmax is not differentiable, as mentioned in my previous post, so you won’t be able to work around it.
I don’t know what Deeplift is, so could you explain it and why class predictions would be needed?

Deeplift is one of the “primary attribution” algorithms that come with Captum for model explainability: https://captum.ai/api/deep_lift.html

I want to visualize attributions of my input time series with respect to the predicted class (i.e., why the model picks a particular class), as done in this example.

Therefore, afaik I need to provide the target class to let the algorithm know with respect to which class the attributions should be computed.

The target tensor can and should contain class indices, and I’m unsure if you have to pass the target tensor as additional information to Deeplift or where these targets should come from.
The linked example also uses the model outputs directly without creating the predicted class labels.
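
To make that concrete, a small sketch of how the target is typically passed to Captum's DeepLift, following the API at https://captum.ai/api/deep_lift.html (model and input_tensors are the placeholders from the earlier posts, not code from this thread). The model's forward returns the raw class scores, argmax is only used outside the graph to pick which output index to explain, and that index goes into the target argument of attribute:

import torch
from captum.attr import DeepLift

# the model's forward returns scores of shape (batch, num_classes); no argmax inside forward
dl = DeepLift(model)

with torch.no_grad():
    # argmax only selects which class to explain; it is not part of the attribution graph
    pred_class = model(input_tensors).argmax(dim=-1)

# attributions are computed with respect to the output index given by `target`
attributions = dl.attribute(input_tensors, target=pred_class)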