BiLSTM forward() - RuntimeError: shape '[-1, 38]' is invalid for input of size 1

Based on an SO post.


forward() now needs to support nn.LSTM(..., bidirectional=True).

I’m basing my latest amendments on this discuss.pytorch.org response.

Error

The error comes from a shape mismatch.

Which tensors need to be reshaped, and for which layers?

I’m out of my depth here.

RuntimeError: shape '[-1, 38]' is invalid for input of size 1
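
For reference, a minimal sketch of the shapes a batch-first bidirectional nn.LSTM produces, using my defaults (embedding_dim=100, hidden_dim=50, batch_size=10) and an arbitrary sequence length of 7:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=50, num_layers=1,
               batch_first=True, bidirectional=True)
x = torch.randn(10, 7, 100)  # (batch, seq_len, embedding_dim)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([10, 7, 100]) -> (batch, seq_len, num_directions * hidden_dim)
print(h_n.shape)  # torch.Size([2, 10, 50])  -> (num_layers * num_directions, batch, hidden_dim)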

Code

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

    def __init__(self, 
        num_classes, 
        batch_size=10,
        embedding_dim=100, 
        hidden_dim=50, 
        vocab_size=128):

        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        
        self.num_layers = 1

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=self.num_layers, batch_first=True, bidirectional=True)  # !
        #self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !
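        # bidirectional LSTM: the forward and backward final states get concatenated,
        # hence the classifier input is 2 * hidden_dim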


    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)


    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        # lstm_out, _ = self.lstm(embeds)  # lstm_out - 2 tensors, _ - hidden layer
        lstm_out, hidden = self.lstm(embeds)
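        # hidden is the tuple (h_n, c_n); with bidirectional=True, h_n has shape
        # (num_layers * num_directions, batch, hidden_dim)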
        
        # Calculate number of directions
        self.num_directions = 2 if self.lstm.bidirectional else 1
        
        # Extract last hidden state
        # final_state = hidden.view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        final_state = hidden[0].view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        # Handle directions
        final_hidden_state = None
        if self.num_directions == 1:
            final_hidden_state = final_state.squeeze(0)
        elif self.num_directions == 2:
            h_1, h_2 = final_state[0], final_state[1]
            # final_hidden_state = h_1 + h_2               # Add both states (requires changes to the input size of first linear layer + attention layer)
            final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states
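            # final_hidden_state: (batch, 2 * hidden_dim), here torch.Size([10, 100])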
        
        print("len(final_hidden_state)", len(final_hidden_state))
        print("len(labels)", len(labels))
        print("final_hidden_state.shape", final_hidden_state.shape)
        print("labels", labels)
        
        self.linear_dims = [0]
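        # NOTE: with linear_dims = [0], the loop below adds no layers,
        # so X passes through unchanged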
        
        # Define set of fully connected layers (Linear Layer + Activation Layer) * #layers
        self.linears = nn.ModuleList()
        for i in range(0, len(self.linear_dims)-1):
            linear_layer = nn.Linear(self.linear_dims[i], self.linear_dims[i+1])
            self.init_weights(linear_layer)
            self.linears.append(linear_layer)
            if i == len(self.linear_dims) - 1:
                break  # no activation after output layer!!!
            self.linears.append(nn.ReLU())
        
        X = final_hidden_state
        
        # Push through linear layers
        for l in self.linears:
            X = l(X)
        
        # tag_space = self.classifier(hidden[:,0,:] + hidden[:,-1,:])  # !  # torch.flip(lstm_out[:,-1,:], [0, 1]) - 1 tensor
        #logits = F.log_softmax(final_hidden_state, dim=1)
        logits = F.cross_entropy(final_hidden_state, labels[0].view(-1))
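        # NOTE: F.cross_entropy above reduces to a single scalar by default (mean),
        # so the logits.view(-1, self.num_labels[0]) below fails with
        # "shape '[-1, 38]' is invalid for input of size 1"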
        loss = None
        if labels:
            # print("len(logits.view(-1, self.num_labels[0]))", len(logits.view(-1, self.num_labels[0])))
            print("len(self.num_labels)", len(self.num_labels))
            print("self.num_labels[0]", self.num_labels[0])
            print("len(labels[0].view(-1))", len(labels[0].view(-1)))
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
        return loss, logits


class LSTMTaggerModel(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        class_map,
        from_checkpoint=False,
        model_name='last.ckpt',
        learning_rate=3e-6,
        **kwargs,
    ):

        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LSTMClassifier(num_classes=num_classes)
        # self.model.load_state_dict(torch.load(model_name), strict=False)  # !
        self.class_map = class_map
        self.num_classes = num_classes
        self.valid_acc = torchmetrics.Accuracy()
        self.valid_f1 = torchmetrics.F1()


    def forward(self, *input, **kwargs):
        return self.model(*input, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        loss, _ = self(x, labels=y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        _, y_pred = self(x, labels=y_true)
        preds = torch.argmax(y_pred, axis=1)
        self.valid_acc(preds, y_true[0])
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.valid_f1(preds, y_true[0])
        self.log('f1', self.valid_f1, prog_bar=True)     

    def configure_optimizers(self):
        'Prepare optimizer and cosine annealing schedule'
        opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.tensor([x['loss']
                                 for x in training_step_outputs]).mean()
        self.log('train_loss', avg_loss)
        print(f'###score: train_loss### {avg_loss}')

    def validation_epoch_end(self, val_step_outputs):
        acc = self.valid_acc.compute()
        f1 = self.valid_f1.compute()
        self.log('val_score', acc)
        self.log('f1', f1)
        print(f'###score: val_score### {acc}')

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("OntologyTaggerModel")       
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-3, type=float)
        return parent_parser

Traceback:

Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type           | Params
---------------------------------------------
0 | model     | LSTMClassifier | 77.4 K
1 | valid_acc | Accuracy       | 0     
2 | valid_f1  | F1             | 0     
---------------------------------------------
77.4 K    Trainable params
0         Non-trainable params
77.4 K    Total params
0.310     Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]
len(final_hidden_state) 10
len(labels) 1
final_hidden_state.shape torch.Size([10, 100])
labels [tensor([ 2, 31, 26, 37, 22,  5, 31, 36,  5, 10])]
len(self.num_labels) 1
self.num_labels[0] 38
len(labels[0].view(-1)) 10
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-3f817f701f20> in <module>
     11     """.split()
     12 
---> 13 run_training(args)

<ipython-input-5-bb0d8b014e32> in run_training(input)
     66         shutil.copyfile(labels_file_orig, labels_file_cp)
     67     trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=loggers)
---> 68     trainer.fit(model, dm)
     69     model_file = os.path.join(args.modeldir, 'last.ckpt')
     70     trainer.save_checkpoint(model_file, weights_only=True)

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    497 
    498         # dispath `start_training` or `start_testing` or `start_predicting`
--> 499         self.dispatch()
    500 
    501         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    544 
    545         else:
--> 546             self.accelerator.start_training(self)
    547 
    548     def train_or_test_or_predict(self):

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     71 
     72     def start_training(self, trainer):
---> 73         self.training_type_plugin.start_training(trainer)
     74 
     75     def start_testing(self, trainer):

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    112     def start_training(self, trainer: 'Trainer') -> None:
    113         # double dispatch to initiate the training loop
--> 114         self._results = trainer.run_train()
    115 
    116     def start_testing(self, trainer: 'Trainer') -> None:

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    605             self.progress_bar_callback.disable()
    606 
--> 607         self.run_sanity_check(self.lightning_module)
    608 
    609         # set stage for logging

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
    858 
    859             # run eval step
--> 860             _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
    861 
    862             self.on_sanity_check_end()

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, max_batches, on_epoch)
    723                 # lightning module methods
    724                 with self.profiler.profile("evaluation_step_and_end"):
--> 725                     output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
    726                     output = self.evaluation_loop.evaluation_step_end(output)
    727 

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py in evaluation_step(self, batch, batch_idx, dataloader_idx)
    164             model_ref._current_fx_name = "validation_step"
    165             with self.trainer.profiler.profile("validation_step"):
--> 166                 output = self.trainer.accelerator.validation_step(args)
    167 
    168         # capture any logged information

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, args)
    175 
    176         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
--> 177             return self.training_type_plugin.validation_step(*args)
    178 
    179     def test_step(self, args):

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
    129 
    130     def validation_step(self, *args, **kwargs):
--> 131         return self.lightning_module.validation_step(*args, **kwargs)
    132 
    133     def test_step(self, *args, **kwargs):

<ipython-input-15-6ef4e0993417> in validation_step(self, batch, batch_idx)
    130     def validation_step(self, batch, batch_idx):
    131         x, y_true = batch
--> 132         _, y_pred = self(x, labels=y_true)
    133         preds = torch.argmax(y_pred, axis=1)
    134         self.valid_acc(preds, y_true[0])

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-15-6ef4e0993417> in forward(self, *input, **kwargs)
    120 
    121     def forward(self, *input, **kwargs):
--> 122         return self.model(*input, **kwargs)
    123 
    124     def training_step(self, batch, batch_idx):

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-15-6ef4e0993417> in forward(self, sentence, labels)
     93             print("self.num_labels[0]", self.num_labels[0])
     94             print("len(labels[0].view(-1))", len(labels[0].view(-1)))
---> 95             loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
     96         return loss, logits
     97 

RuntimeError: shape '[-1, 38]' is invalid for input of size 1

My problem was two things.

First, I had to run classifier() before calculating cross_entropy().

Second, I had to pass X, i.e. final_hidden_state (flattened), to the classifier:

X = final_hidden_state

# Push through linear layers
for l in self.linears:
    X = l(X)

logits = self.classifier(X)

This achieves a working model. However, the first epoch’s validation score is 0%.

This will require further work.
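
For completeness, here is a minimal sketch of how the corrected forward() fits together, assuming bidirectional=True and dropping the empty self.linears loop (linear_dims = [0] builds no layers anyway):

def forward(self, sentence, labels=None):
    embeds = self.word_embeddings(sentence)               # (batch, seq_len, embedding_dim)
    lstm_out, (h_n, c_n) = self.lstm(embeds)

    # h_n: (num_layers * num_directions, batch, hidden_dim)
    h_n = h_n.view(self.num_layers, 2, -1, self.hidden_dim)[-1]
    final_hidden_state = torch.cat((h_n[0], h_n[1]), 1)   # (batch, 2 * hidden_dim)

    logits = self.classifier(final_hidden_state)          # (batch, num_classes), here (10, 38)
    loss = None
    if labels is not None:
        loss = F.cross_entropy(logits, labels[0].view(-1))
    return loss, logits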