RuntimeError: shape '[-1, 3]' is invalid for input of size 16384

I want to finetune “cross-encoder/nli-distilroberta-base” this model to my own binary classification task, the original pretrained model is a 3 classes model, so I need to change it to 2 classes in the output, I do not know what causes this error, plz help me!!! And this error occurs when the validation sanity check starts.

### Define the LightningModule
class KPM(pl.LightningModule):
    def __init__(self, model, learning_rate=2e-5, weight_decay=0.001):
        super(KPM, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
        outputs = self(input_ids, attention_mask, labels)
        print("--------------training-------------------------------")
        print("Shape: "+ str(outputs.logits.shape))
        print(outputs.logits)
        # self.log("train_loss", torch.clone(outputs.loss).detach())
        # one_hot_labels = torch.stack([1 - labels, labels], dim=-1)
        # one_hot_labels = one_hot_labels.float()
        loss = outputs.loss
        # loss = F.binary_cross_entropy_with_logits(outputs.logits, one_hot_labels, reduction='mean')
        self.log("train_loss", torch.clone(loss).detach())
        return loss

    # def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
    #     input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
    #     print(len(input_ids))
    #     # inputs = batch[:-1]  # Assuming the input tensor is at the first index
    #     # print(f"Validation Batch {batch_idx + 1}, Input Tensor Size: {inputs[0].size()}")
    #     # print(inputs[0])  # Assuming the input tensor is at index 0

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
        outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", torch.clone(outputs.loss).detach())
        # one_hot_labels = torch.stack([1 - labels, labels], dim=-1)
        # one_hot_labels = one_hot_labels.float()
        print(labels)
        loss = outputs.loss
        # loss = F.binary_cross_entropy_with_logits(outputs.logits, one_hot_labels, reduction='mean')
        self.log("val_loss", torch.clone(loss).detach())
        return  loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
### Training phase
model_name = "cross-encoder/nli-distilroberta-base"
num_classes = 2
max_length = 512
batch_size = 16
learning_rate = 5e-05
weight_decay = 0.001

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# model = AutoModelForSequenceClassification.from_pretrained(model_name)
# model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)
# Change the loss function to binary cross-entropy
# print(model.type)
# print(model.config)
# Modify the classifier for binary classification
model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)

# Change the loss function to binary cross-entropy
model.config.num_labels = num_classes
model.config.id2label = {i: str(i) for i in range(num_classes)}
model.config.label2id = {str(i): i for i in range(num_classes)}
print(model.config)

# Check if the "CLS" token is present in the vocabulary
vocab = tokenizer.get_vocab()
cls_token_present = tokenizer.cls_token in vocab
sep_token_present = tokenizer.sep_token in vocab
print("Is CLS token present in the vocabulary?", cls_token_present)
print("Is SEP token present in the vocabulary?", sep_token_present)

train_dataset = Dataset(train_df, tokenizer, max_length)
val_dataset = Dataset(val_df, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Assuming you have a DataLoader named val_loader
for batch in val_loader:
    inputs = batch['input_ids']  # Replace 'input_ids' with the key used in your dataset for input tokens
    labels = batch['labels']  # Replace 'labels' with the key used in your dataset for labels

    # Print or inspect the shape of input and labels
    print(f"Input shape: {inputs.shape}")
    print(f"Labels shape: {labels.shape}")

    # Optionally, print the actual input and labels values
    print("Input values:", inputs)
    print("Labels values:", labels)

    break  # Stop after inspecting the first batch


model = KPM(model, learning_rate)

# Define a ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoint',
    filename='nli_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping =EarlyStopping(
    monitor="val_loss",
    min_delta=0.01,
    patience=3
)

# define trainer
trainer = Trainer(
    min_epochs = 0, # change this
    max_epochs = 20,
    callbacks=[checkpoint_callback, early_stopping],
    accelerator="auto",
    #progress_bar_refresh_rate=30,
    #gpus = 1 if device.type == 'cuda' else 0
    devices = 1 if torch.cuda.is_available() else None
)

torch.cuda.empty_cache()

# start training
trainer.fit(model, train_loader, val_loader)
print("Training is finished")
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ C:\Users\jhluo\AppData\Local\Temp\ipykernel_8180\4216769630.py:84 in <cell line: 84>             │
│                                                                                                  │
│ [Errno 2] No such file or directory:                                                             │
│ 'C:\\Users\\jhluo\\AppData\\Local\\Temp\\ipykernel_8180\\4216769630.py'                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:740 in fit                   │
│                                                                                                  │
│    737 │   │   │   │   " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"          │
│    738 │   │   │   )                                                                             │
│    739 │   │   │   train_dataloaders = train_dataloader                                          │
│ ❱  740 │   │   self._call_and_handle_interrupt(                                                  │
│    741 │   │   │   self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_  │
│    742 │   │   )                                                                                 │
│    743                                                                                           │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:685 in                       │
│ _call_and_handle_interrupt                                                                       │
│                                                                                                  │
│    682 │   │   │   **kwargs: keyword arguments to be passed to `trainer_fn`                      │
│    683 │   │   """                                                                               │
│    684 │   │   try:                                                                              │
│ ❱  685 │   │   │   return trainer_fn(*args, **kwargs)                                            │
│    686 │   │   # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7  │
│    687 │   │   except KeyboardInterrupt as exception:                                            │
│    688 │   │   │   rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown..."  │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:777 in _fit_impl             │
│                                                                                                  │
│    774 │   │                                                                                     │
│    775 │   │   # TODO: ckpt_path only in v1.7                                                    │
│    776 │   │   ckpt_path = ckpt_path or self.resume_from_checkpoint                              │
│ ❱  777 │   │   self._run(model, ckpt_path=ckpt_path)                                             │
│    778 │   │                                                                                     │
│    779 │   │   assert self.state.stopped                                                         │
│    780 │   │   self.training = False                                                             │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:1199 in _run                 │
│                                                                                                  │
│   1196 │   │   self.checkpoint_connector.resume_end()                                            │
│   1197 │   │                                                                                     │
│   1198 │   │   # dispatch `start_training` or `start_evaluating` or `start_predicting`           │
│ ❱ 1199 │   │   self._dispatch()                                                                  │
│   1200 │   │                                                                                     │
│   1201 │   │   # plugin will finalized fitting (e.g. ddp_spawn will load trained model)          │
│   1202 │   │   self._post_dispatch()                                                             │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:1279 in _dispatch            │
│                                                                                                  │
│   1276 │   │   elif self.predicting:                                                             │
│   1277 │   │   │   self.training_type_plugin.start_predicting(self)                              │
│   1278 │   │   else:                                                                             │
│ ❱ 1279 │   │   │   self.training_type_plugin.start_training(self)                                │
│   1280 │                                                                                         │
│   1281 │   def run_stage(self):                                                                  │
│   1282 │   │   self.accelerator.dispatch(self)                                                   │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py:2 │
│ 02 in start_training                                                                             │
│                                                                                                  │
│   199 │                                                                                          │
│   200 │   def start_training(self, trainer: "pl.Trainer") -> None:                               │
│   201 │   │   # double dispatch to initiate the training loop                                    │
│ ❱ 202 │   │   self._results = trainer.run_stage()                                                │
│   203 │                                                                                          │
│   204 │   def start_evaluating(self, trainer: "pl.Trainer") -> None:                             │
│   205 │   │   # double dispatch to initiate the test loop                                        │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:1289 in run_stage            │
│                                                                                                  │
│   1286 │   │   │   return self._run_evaluate()                                                   │
│   1287 │   │   if self.predicting:                                                               │
│   1288 │   │   │   return self._run_predict()                                                    │
│ ❱ 1289 │   │   return self._run_train()                                                          │
│   1290 │                                                                                         │
│   1291 │   def _pre_training_routine(self):                                                      │
│   1292 │   │   # wait for all to join if on distributed                                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:1311 in _run_train           │
│                                                                                                  │
│   1308 │   │   if not self.is_global_zero and self.progress_bar_callback is not None:            │
│   1309 │   │   │   self.progress_bar_callback.disable()                                          │
│   1310 │   │                                                                                     │
│ ❱ 1311 │   │   self._run_sanity_check(self.lightning_module)                                     │
│   1312 │   │                                                                                     │
│   1313 │   │   # enable train mode                                                               │
│   1314 │   │   self.model.train()                                                                │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py:1375 in _run_sanity_check    │
│                                                                                                  │
│   1372 │   │   │                                                                                 │
│   1373 │   │   │   # run eval step                                                               │
│   1374 │   │   │   with torch.no_grad():                                                         │
│ ❱ 1375 │   │   │   │   self._evaluation_loop.run()                                               │
│   1376 │   │   │                                                                                 │
│   1377 │   │   │   self.call_hook("on_sanity_check_end")                                         │
│   1378                                                                                           │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\loops\base.py:145 in run                        │
│                                                                                                  │
│   142 │   │   while not self.done:                                                               │
│   143 │   │   │   try:                                                                           │
│   144 │   │   │   │   self.on_advance_start(*args, **kwargs)                                     │
│ ❱ 145 │   │   │   │   self.advance(*args, **kwargs)                                              │
│   146 │   │   │   │   self.on_advance_end()                                                      │
│   147 │   │   │   │   self.restarting = False                                                    │
│   148 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py:110 in      │
│ advance                                                                                          │
│                                                                                                  │
│   107 │   │   )                                                                                  │
│   108 │   │   dl_max_batches = self._max_batches[dataloader_idx]                                 │
│   109 │   │                                                                                      │
│ ❱ 110 │   │   dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, sel   │
│   111 │   │                                                                                      │
│   112 │   │   # store batch level output per dataloader                                          │
│   113 │   │   self.outputs.append(dl_outputs)                                                    │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\loops\base.py:145 in run                        │
│                                                                                                  │
│   142 │   │   while not self.done:                                                               │
│   143 │   │   │   try:                                                                           │
│   144 │   │   │   │   self.on_advance_start(*args, **kwargs)                                     │
│ ❱ 145 │   │   │   │   self.advance(*args, **kwargs)                                              │
│   146 │   │   │   │   self.on_advance_end()                                                      │
│   147 │   │   │   │   self.restarting = False                                                    │
│   148 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py:122 in     │
│ advance                                                                                          │
│                                                                                                  │
│   119 │   │                                                                                      │
│   120 │   │   # lightning module methods                                                         │
│   121 │   │   with self.trainer.profiler.profile("evaluation_step_and_end"):                     │
│ ❱ 122 │   │   │   output = self._evaluation_step(batch, batch_idx, dataloader_idx)               │
│   123 │   │   │   output = self._evaluation_step_end(output)                                     │
│   124 │   │                                                                                      │
│   125 │   │   self.batch_progress.increment_processed()                                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py:217 in     │
│ _evaluation_step                                                                                 │
│                                                                                                  │
│   214 │   │   else:                                                                              │
│   215 │   │   │   self.trainer.lightning_module._current_fx_name = "validation_step"             │
│   216 │   │   │   with self.trainer.profiler.profile("validation_step"):                         │
│ ❱ 217 │   │   │   │   output = self.trainer.accelerator.validation_step(step_kwargs)             │
│   218 │   │                                                                                      │
│   219 │   │   return output                                                                      │
│   220                                                                                            │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\accelerators\accelerator.py:239 in              │
│ validation_step                                                                                  │
│                                                                                                  │
│   236 │   │   See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` fo   │
│   237 │   │   """                                                                                │
│   238 │   │   with self.precision_plugin.val_step_context():                                     │
│ ❱ 239 │   │   │   return self.training_type_plugin.validation_step(*step_kwargs.values())        │
│   240 │                                                                                          │
│   241 │   def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT   │
│   242 │   │   """The actual test step.                                                           │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py:2 │
│ 19 in validation_step                                                                            │
│                                                                                                  │
│   216 │   │   pass                                                                               │
│   217 │                                                                                          │
│   218 │   def validation_step(self, *args, **kwargs):                                            │
│ ❱ 219 │   │   return self.model.validation_step(*args, **kwargs)                                 │
│   220 │                                                                                          │
│   221 │   def test_step(self, *args, **kwargs):                                                  │
│   222 │   │   return self.model.test_step(*args, **kwargs)                                       │
│                                                                                                  │
│ C:\Users\jhluo\AppData\Local\Temp\ipykernel_8180\2205988259.py:35 in validation_step             │
│                                                                                                  │
│ [Errno 2] No such file or directory:                                                             │
│ 'C:\\Users\\jhluo\\AppData\\Local\\Temp\\ipykernel_8180\\2205988259.py'                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\torch\nn\modules\module.py:1110 in _call_impl                     │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ C:\Users\jhluo\AppData\Local\Temp\ipykernel_8180\2205988259.py:9 in forward                      │
│                                                                                                  │
│ [Errno 2] No such file or directory:                                                             │
│ 'C:\\Users\\jhluo\\AppData\\Local\\Temp\\ipykernel_8180\\2205988259.py'                          │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\torch\nn\modules\module.py:1110 in _call_impl                     │
│                                                                                                  │
│   1107 │   │   # this function, and just call forward.                                           │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1111 │   │   # Do not call functions when jit is used                                          │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ D:\anaconda3\lib\site-packages\transformers\models\roberta\modeling_roberta.py:1240 in forward   │
│                                                                                                  │
│   1237 │   │   │   │   │   loss = loss_fct(logits, labels)                                       │
│   1238 │   │   │   elif self.config.problem_type == "single_label_classification":               │
│   1239 │   │   │   │   loss_fct = CrossEntropyLoss()                                             │
│ ❱ 1240 │   │   │   │   loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))        │
│   1241 │   │   │   elif self.config.problem_type == "multi_label_classification":                │
│   1242 │   │   │   │   loss_fct = BCEWithLogitsLoss()                                            │
│   1243 │   │   │   │   loss = loss_fct(logits, labels)                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: shape '[-1, 3]' is invalid for input of size 16384

It seems you are trying to flatten the logits into 3 classes here:

logits.view(-1, self.num_labels)

as self.num_labels is 3. If you want to reduce the number of classes to 2, you might need to double check this variable.