Bidirectional LSTM gets val_score: 0.0

Goal: Optimise and better understand BiLSTM

I’ve a working BiLSTM. However, the 1st epoch’s val_score: 0.0.

I assume this blunt problem is caused by my mishandling of training with this additional layer.

Questions:

  • What might cause such a problem with specifically BiLSTMs?
  • What have I not/ incorrectly implemented?

Code:

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

    def __init__(self, 
        num_classes, 
        batch_size=10,
        embedding_dim=100, 
        hidden_dim=50, 
        vocab_size=128):

        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        
        self.num_layers = 1

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=self.num_layers, batch_first=True, bidirectional=True)  # !
        #self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !


    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)


    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        # lstm_out, _ = self.lstm(embeds)  # lstm_out - 2 tensors, _ - hidden layer
        lstm_out, hidden = self.lstm(embeds)
        
        # Calculate number of directions
        self.num_directions = 2 if self.lstm.bidirectional == True else 1
        
        # Extract last hidden state
        # final_state = hidden.view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        final_state = hidden[0].view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        # Handle directions
        final_hidden_state = None
        if self.num_directions == 1:
            final_hidden_state = final_state.squeeze(0)
        elif self.num_directions == 2:
            h_1, h_2 = final_state[0], final_state[1]
            # final_hidden_state = h_1 + h_2               # Add both states (requires changes to the input size of first linear layer + attention layer)
            final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states
        
        self.linear_dims = [0]
        
        # Define set of fully connected layers (Linear Layer + Activation Layer) * #layers
        self.linears = nn.ModuleList()
        for i in range(0, len(self.linear_dims)-1):
            linear_layer = nn.Linear(self.linear_dims[i], self.linear_dims[i+1])
            self.init_weights(linear_layer)
            self.linears.append(linear_layer)
            if i == len(self.linear_dims) - 1:
                break  # no activation after output layer!!!
            self.linears.append(nn.ReLU())
        
        X = final_hidden_state
        
        # Push through linear layers
        for l in self.linears:
            X = l(X)

        print('type(X)', type(X))
        print('len(X)', len(X))
        print('X.shape', X.shape)
        
        print('type(labels)', type(labels))
        print('len(labels)', len(labels))
        print('labels', labels)
        
        print('type(hidden[0])', type(hidden[0]))  # hidden[0] - tensor
        print('len(hidden[0])', len(hidden[0]))
        print('hidden[0].shape', hidden[0].shape)
        print('hidden[0]', labels)

        
        logits = self.classifier(X)  # !  # torch.flip(lstm_out[:,-1,:], [0, 1]) - 1 tensor
        print('type(logits)', type(logits))
        print('len(logits)', len(logits))
        print('logits.shape', logits.shape)
        
        loss = None
        if labels:
            print("len(self.num_labels)", len(self.num_labels))
            print("self.num_labels[0]", self.num_labels[0])
            print("len(labels[0].view(-1))", len(labels[0].view(-1)))
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
            
        return loss, logits


class LSTMTaggerModel(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        class_map,
        from_checkpoint=False,
        model_name='last.ckpt',
        learning_rate=3e-6,
        **kwargs,
    ):

        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LSTMClassifier(num_classes=num_classes)
        # self.model.load_state_dict(torch.load(model_name), strict=False)  # !
        self.class_map = class_map
        self.num_classes = num_classes
        self.valid_acc = torchmetrics.Accuracy()
        self.valid_f1 = torchmetrics.F1()


    def forward(self, *input, **kwargs):
        return self.model(*input, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        loss, _ = self(x, labels=y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        _, y_pred = self(x, labels=y_true)
        preds = torch.argmax(y_pred, axis=1)
        self.valid_acc(preds, y_true[0])
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.valid_f1(preds, y_true[0])
        self.log('f1', self.valid_f1, prog_bar=True)     

    def configure_optimizers(self):
        'Prepare optimizer and schedule (linear warmup and decay)'
        opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.tensor([x['loss']
                                 for x in training_step_outputs]).mean()
        self.log('train_loss', avg_loss)
        print(f'###score: train_loss### {avg_loss}')

    def validation_epoch_end(self, val_step_outputs):
        acc = self.valid_acc.compute()
        f1 = self.valid_f1.compute()
        self.log('val_score', acc)
        self.log('f1', f1)
        print(f'###score: val_score### {acc}')

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("OntologyTaggerModel")       
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-3, type=float)
        return parent_parser

Traceback:

print statements repeat.

Global seed set to 42
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/deprecate/deprecation.py:115: FutureWarning: The `F1` was deprecated since v0.7 in favor of `torchmetrics.classification.f_beta.F1Score`. It will be removed in v0.8.
  stream(template_mgs % msg_args)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration. You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).
  warnings.warn(*args, **kwargs)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type           | Params
---------------------------------------------
0 | model     | LSTMClassifier | 77.4 K
1 | valid_acc | Accuracy       | 0     
2 | valid_f1  | F1             | 0     
---------------------------------------------
77.4 K    Trainable params
0         Non-trainable params
77.4 K    Total params
0.310     Total estimated model params size (MB)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
  warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Validation sanity check: 0it [00:00, ?it/s]
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric F1 was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)

type(X) <class 'torch.Tensor'>
len(X) 10
X.shape torch.Size([10, 100])
type(labels) <class 'list'>
len(labels) 1
labels [tensor([ 2, 31, 26, 37, 22,  5, 31, 36,  5, 10])]
type(hidden[0]) <class 'torch.Tensor'>
len(hidden[0]) 2
hidden[0].shape torch.Size([2, 10, 50])
hidden[0] [tensor([ 2, 31, 26, 37, 22,  5, 31, 36,  5, 10])]
type(logits) <class 'torch.Tensor'>
len(logits) 10
logits.shape torch.Size([10, 38])
len(self.num_labels) 1
self.num_labels[0] 38
len(labels[0].view(-1)) 10
...

The code snippet below – maybe some more – belongs in the __init__ method. Right now, you re-create the last linear layers in each call for forward.