TypeError: `Trainer.fit()` requires a `LightningModule`, got: NERTokenClassifier

Hello All,

This is really a question about imports and about subclassing: wrapping a plain torch.nn.Module inside a pl.LightningModule.

I am receiving an error saying that Trainer.fit() requires a LightningModule. Following Train a model (basic) — PyTorch Lightning 1.7.5 documentation, I defined an NERModel subclassed from torch.nn.Module, then created an NERTokenClassifier subclassed from pl.LightningModule. Within its __init__ I set self.model to an NERModel instance. I have provided the code below; excuse the length.

Any advice on how to correct this?
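
To make the structure clearer, here is the wiring I am aiming for, reduced to a minimal self-contained sketch (TinyNet and LitWrapper are stand-in names; my real Trainer call lives in a separate training script and is not shown in this post):

import lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class TinyNet(torch.nn.Module):
    """Plain torch module, standing in for NERModel."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


class LitWrapper(pl.LightningModule):
    """LightningModule wrapping the torch module, standing in for NERTokenClassifier."""

    def __init__(self):
        super().__init__()
        self.model = TinyNet()  # the torch module lives inside the LightningModule

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    # In my project, the equivalent call, trainer.fit(NERTokenClassifier(n_tags=...),
    # datamodule=NERDataModule()), is what raises:
    # TypeError: `Trainer.fit()` requires a `LightningModule`, got: NERTokenClassifier
    trainer.fit(LitWrapper(), DataLoader(ds, batch_size=8))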

Environment requirements

# requirements.txt
torch==2.0.1
transformers==4.29.2
polars==0.18.0
sklearn
seqeval==1.2.2
evaluate==0.4.0
lightning==1.9.0
pytorch-lightning==1.9.0
comet-ml==3.23.0
numpy
requests_kerberos

Model and LightningModule code

import logging
import os
import pathlib
from argparse import ArgumentParser

import lightning as pl
import numpy as np
import sklearn.preprocessing
import torch
import transformers
from evaluate import load
from torch.utils.data import DataLoader, Dataset

# `settings`, `get_data` and `get_tag_scheme` come from project-internal modules;
# their imports are omitted from this snippet.

class NERModel(torch.nn.Module):

    def __init__(self,
                 n_tags: int, dropout: float = 0.1, model_name: str = settings.MODELS["bert-ner-lg"],
                 **kwargs):
        super().__init__()

        self.n_tags = n_tags
        self.transformer = transformers.AutoModel.from_pretrained(model_name)
        # extract transformer name
        self.transformer_name = self.transformer.name_or_path
        # extract AutoConfig, from which relevant parameters can be extracted
        self.transformer_config = transformers.AutoConfig.from_pretrained(model_name)

        self.dropout = torch.nn.Dropout(dropout)
        self.tags = torch.nn.Linear(self.transformer_config.hidden_size, n_tags)

    def forward(self, batch) -> torch.Tensor:
        """Model forward pass.

        Args:
            batch (dict): Batch containing ``input_ids`` and ``attention_mask`` tensors.

        Returns:
            torch.Tensor: Tag logits of shape (bs, seq_len, n_tags).
        """

        outputs = self.transformer(input_ids=batch["input_ids"],
                                   attention_mask=batch["attention_mask"])

        hidden_state = outputs[0]  # (bs, seq_len, dim)

        # apply drop-out
        outputs = self.dropout(hidden_state)

        # outputs for all labels/tags
        outputs = self.tags(outputs)

        return outputs



class NERTokenClassifier(pl.LightningModule):

    def __init__(self, n_tags: int, learning_rate: float = 0.0001 * 8, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.n_tags = n_tags
        # Metrics
        self.metric = load("seqeval")
        self.model = NERModel(n_tags=self.n_tags)

    def training_step(self, batch, batch_nb):
        target_tags = batch["target_tags"]
        # fwd
        y_hat = self.model(batch)

        # loss
        loss_fct = torch.nn.CrossEntropyLoss()

        # Compute active loss so as to not compute loss of paddings
        active_loss = batch["attention_mask"].view(-1) == 1

        active_logits = y_hat.view(-1, self.n_tags)
        active_labels = torch.where(
            active_loss,
            target_tags.view(-1),
            torch.tensor(loss_fct.ignore_index).type_as(target_tags)
        )

        # Only compute loss on actual token predictions
        loss = loss_fct(active_logits, active_labels)

        # logs
        self.log_dict({"train_loss":loss}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        target_tags = batch["target_tags"]
        # fwd
        y_hat = self.model(batch)

        # loss
        loss_fct = torch.nn.CrossEntropyLoss()

        # Compute active loss so as to not compute loss of paddings
        active_loss = batch["attention_mask"].view(-1) == 1

        active_logits = y_hat.view(-1, self.n_tags)
        active_labels = torch.where(
            active_loss,
            target_tags.view(-1),
            torch.tensor(loss_fct.ignore_index).type_as(target_tags)
        )

        # Only compute loss on actual token predictions
        loss = loss_fct(active_logits, active_labels)

        metrics = self.compute_metrics([y_hat, target_tags])

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log_dict({"val_loss": loss, "val_f1": metrics["f1"], "val_accuracy": metrics["accuracy"],
                       "val_precision": metrics["precision"], "val_recall": metrics["recall"]}, prog_bar=True)
        return loss

    def test_step(self, batch, batch_nb):
        target_tags = batch["target_tags"]
        # fwd
        y_hat = self.model(batch)

        # loss
        loss_fct = torch.nn.CrossEntropyLoss()
        # Compute active loss so as to not compute loss of paddings
        active_loss = batch["attention_mask"].view(-1) == 1

        active_logits = y_hat.view(-1, self.n_tags)
        active_labels = torch.where(
            active_loss,
            target_tags.view(-1),
            torch.tensor(loss_fct.ignore_index).type_as(target_tags)
        )

        # Only compute loss on actual token predictions
        loss = loss_fct(active_logits, active_labels)
        metrics = self.compute_metrics([y_hat, target_tags])

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log_dict({"test_loss": loss, "test_f1": metrics["f1"], "test_accuracy": metrics["accuracy"],
                       "test_precision": metrics["precision"], "test_recall": metrics["recall"]}, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        # fwd
        y_hat = self.model(batch)
        return {"logits": y_hat,
                "target_tags": batch["target_tags"],
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"]}

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad],
                                     lr=self.hparams.learning_rate,
                                     eps=1e-08)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=2e-5,
                steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
                epochs=self.hparams.max_epochs),
            "interval": "step"  # called after each training step
        }
        #scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-5, total_steps=2000)
        #scheduler = StepLR(optimizer, step_size=1, gamma=0.2)
        #scheduler = ReduceLROnPlateau(optimizer, patience=0, factor=0.2)

        return [optimizer], [scheduler]


    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no-cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser])

        # network params
        #parser.add_argument('--drop_prob', default=0.2, type=float)

        # data
        parser.add_argument("--data_root", default=os.path.join(root_dir, settings.DATA_DIRECTORY), type=str)

        # training params (opt)
        parser.add_argument("--learning_rate", default=2e-5, type=float, help = "type (default: %(default)f)")
        return parser
    # ---------------------
    # EVALUATE PERFORMANCE
    # ---------------------

    def compute_metrics(self, p):
        predictions, labels = p
        predictions = torch.argmax(predictions, dim=2)
        label_len = len(self.trainer.datamodule.tag_complete)
        label_list = self.trainer.datamodule.tag_encoder.inverse_transform(np.arange(label_len))

        # Remove ignored index (special tokens)
        true_predictions = [
            [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        results = self.metric.compute(predictions=true_predictions, references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

The Dataset and DataModule classes are below for the full picture.

class NERDataSet(Dataset):
    """Generic NERDA DataSetReader"""

    def __init__(self,
                 examples,
                 tokenizer: transformers.PreTrainedTokenizer,
                 tag_encoder: sklearn.preprocessing.LabelEncoder,
                 label_all_tokens: bool = False
                 ) -> None:
        """Initialize DataSetReader
        Initializes DataSetReader that prepares and preprocesses
        DataSet for Named-Entity Recognition Task and training.
        Args:
            sentences (list): Sentences.
            tags (list): Named-Entity tags.
            transformer_tokenizer (transformers.PreTrainedTokenizer):
                tokenizer for transformer.
            transformer_config (transformers.PretrainedConfig): Config
                for transformer model.
            max_len (int): Maximum length of sentences after applying
                transformer tokenizer.
            tag_encoder (sklearn.preprocessing.LabelEncoder): Encoder
                for Named-Entity tags.
            tag_outside (str): Special Outside tag.
        """
        self.sentences = examples["sentences"]
        self.tags = examples["tags"]
        self.tokenizer = tokenizer
        self.tag_encoder = tag_encoder
        self.label_all_tokens = label_all_tokens

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, item):
        tags = self.tags[item]
        # encode tags and sentence words
        tags = self.tag_encoder.transform(tags)
        tokenized_inputs = self.tokenizer(self.sentences[item], truncation=True, is_split_into_words=True)

        word_ids = tokenized_inputs.word_ids()
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(tags[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.  A word could be split into two or more tokens occasionally depending on the model tokenizer
            else:
                label_ids.append(tags[word_idx] if self.label_all_tokens else -100)
            previous_word_idx = word_idx

        tokenized_inputs["target_tags"] = label_ids
        return tokenized_inputs


class NERDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 16, num_workers: int = 2, model_name: str = settings.MODELS["bert-ner-lg"]):
        super().__init__()

        # Defining batch size of our data
        self.batch_size = batch_size

        # Defining num_workers
        self.num_workers = num_workers

        # Defining Tokenizers
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

        self.label_pad_token_id = -100

    def prepare_data(self):
        self.train_data = get_data(path=settings.TRAIN)
        self.val_data = get_data(path=settings.VALID)
        self.test_data = get_data(path=settings.TEST)

        self.tag_complete = get_tag_scheme(path=settings.TRAIN)
        self.tag_encoder = sklearn.preprocessing.LabelEncoder()
        self.tag_encoder.fit(self.tag_complete)

    def setup(self, stage=None):
        # Loading the dataset
        self.train_dataset = NERDataSet(self.train_data, tokenizer=self.tokenizer, tag_encoder=self.tag_encoder, label_all_tokens=True)
        self.val_dataset = NERDataSet(self.val_data, tokenizer=self.tokenizer, tag_encoder=self.tag_encoder, label_all_tokens=True)
        self.test_dataset = NERDataSet(self.test_data, tokenizer=self.tokenizer, tag_encoder=self.tag_encoder, label_all_tokens=True)

    def custom_collate(self, features):
        label_name = "target_tags"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        batch = self.tokenizer.pad(
            features,
            padding=True,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors="pt" if labels is None else None,
        )

        if labels is None:
            return batch

        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
        padding_side = self.tokenizer.padding_side
        if padding_side == "right":
            batch[label_name] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
        else:
            batch[label_name] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]

        batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}

        return batch

    def train_dataloader(self):
        #dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        #return DataLoader(train_dataset, sampler=dist_sampler, batch_size=32) # For use in Multiple GPUs
        return DataLoader(self.train_dataset, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.custom_collate)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.custom_collate)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.custom_collate)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.custom_collate)


I have since rewritten the same training loop in plain torch and it works fine, so there must be a bug somewhere in the code above.