How to make an LSTM Bidirectional?

Based on SO post. PyTorch GitHub advised me to post on here.

Goal: make LSTM self.classifier() learn from bidirectional layers.

# ! = code lines of interest

Question:
What changes to LSTMClassifier do I need to make, in order to have this LSTM work bidirectionally?


I think the problem is in forward(). It learns from the last state of LSTM neural network, by slicing:

tag_space = self.classifier(lstm_out[:,-1,:])

However, bidirectional changes the architecture and thus the output shape.

Do I need to sum up or concatenate the values of the 2 layers/ directions?


Installs:

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
!pip -q install pytorch-lightning==1.2.7 torchmetrics awscli mlflow boto3 pycm
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchtext==0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html

Working Code:

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

    def __init__(self, 
        num_classes, 
        batch_size=10,
        embedding_dim=100, 
        hidden_dim=50, 
        vocab_size=128):

        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)  # !
        
        print("# !")
        
        bi_grus = torch.nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        reverse_gru = torch.nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=False)
        
        self.lstm.weight_ih_l0_reverse = bi_grus.weight_ih_l0_reverse
        self.lstm.weight_hh_l0_reverse = bi_grus.weight_hh_l0_reverse
        self.lstm.bias_ih_l0_reverse = bi_grus.bias_ih_l0_reverse
        self.lstm.bias_hh_l0_reverse = bi_grus.bias_hh_l0_reverse
        
        bi_output, bi_hidden = bi_grus()
        reverse_output, reverse_hidden = reverse_gru()
        
        print("# !")

        # self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !


    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)


    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)  # lstm_out - 2 tensors, _ - hidden layer
        print(lstm_out[:,-1,:])
        tag_space = self.classifier(lstm_out[:,-1,:] + lstm_out[:,-1,:])  # !  # lstm_out[:,-1,:] - 1 tensor
        logits = F.log_softmax(tag_space, dim=1)
        loss = None
        if labels:
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
        return loss, logits


class LSTMTaggerModel(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        class_map,
        from_checkpoint=False,
        model_name='last.ckpt',
        learning_rate=3e-6,
        **kwargs,
    ):

        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LSTMClassifier(num_classes=num_classes)
        self.model.load_state_dict(torch.load(model_name), strict=False)  # !
        self.class_map = class_map
        self.num_classes = num_classes
        self.valid_acc = torchmetrics.Accuracy()
        self.valid_f1 = torchmetrics.F1()


    def forward(self, *input, **kwargs):
        return self.model(*input, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        loss, _ = self(x, labels=y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        _, y_pred = self(x, labels=y_true)
        preds = torch.argmax(y_pred, axis=1)
        self.valid_acc(preds, y_true[0])
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.valid_f1(preds, y_true[0])
        self.log('f1', self.valid_f1, prog_bar=True)     

    def configure_optimizers(self):
        'Prepare optimizer and schedule (linear warmup and decay)'
        opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.tensor([x['loss']
                                 for x in training_step_outputs]).mean()
        self.log('train_loss', avg_loss)
        print(f'###score: train_loss### {avg_loss}')

    def validation_epoch_end(self, val_step_outputs):
        acc = self.valid_acc.compute()
        f1 = self.valid_f1.compute()
        self.log('val_score', acc)
        self.log('f1', f1)
        print(f'###score: val_score### {acc}')

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("OntologyTaggerModel")       
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-3, type=float)
        return parent_parser

Traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-15-b94d572a1b68> in <module>
     11     """.split()
     12 
---> 13 run_training(args)

<ipython-input-6-bb0d8b014e32> in run_training(input)
     54     elif args.checkpointfile:
     55         file_path = os.path.join(args.traindir, args.checkpointfile)
---> 56         model = LSTMTaggerModel.load_from_checkpoint(file_path)
     57     else:
     58         model = LSTMTaggerModel(**vars(args), num_classes=dm.num_classes, class_map=dm.class_map)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    155         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    156 
--> 157         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    158         return model
    159 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    203 
    204         # load the state_dict on the model automatically
--> 205         model.load_state_dict(checkpoint['state_dict'], strict=strict)
    206 
    207         return model

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for LSTMTaggerModel:
	Missing key(s) in state_dict: "model.lstm.weight_ih_l0_reverse", "model.lstm.weight_hh_l0_reverse", "model.lstm.bias_ih_l0_reverse", "model.lstm.bias_hh_l0_reverse". 
	size mismatch for model.classifier.weight: copying a param with shape torch.Size([38, 50]) from checkpoint, the shape in current model is torch.Size([38, 100]).

Key Error:

size mismatch for model.classifier.weight: copying a param with shape torch.Size([38, 50]) from checkpoint, the shape in current model is torch.Size([38, 100]).

last.ckpt is torch.Size([38, 50]), but my code is torch.Size([38, 100]).

Versions

PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A

OS: Amazon Linux AMI 2018.03 (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-28)
Clang version: Could not collect
CMake version: version 3.22.0
Libc version: glibc-2.10

Python version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:53)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.14.252-131.483.amzn1.x86_64-x86_64-with-glibc2.10
Is CUDA available: N/A
CUDA runtime version: 10.0.130
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/lib64/libcudnn.so.7.5.1
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.4
[conda] numpy                     1.21.4           py37h31617e3_0    conda-forge

This is my first post :blush:

The shape mismatch error and the missing keys are raised, since you are trying to load a state_dict from an LSTM module with bidirectional=False to another one where bidirectional=True was set as already described in your StackOverflow cross-post and as seen in this example:

lstm = nn.LSTM(input_size=2, hidden_size=3, num_layers=2, bidirectional=False)
sd = lstm.state_dict()

lstm_bidirectional = nn.LSTM(input_size=2, num_layers=2, hidden_size=3, bidirectional=True)
lstm_bidirectional.load_state_dict(sd)
# > RuntimeError: Error(s) in loading state_dict for LSTM:
#	Missing key(s) in state_dict: "weight_ih_l0_reverse", "weight_hh_l0_reverse", "bias_ih_l0_reverse", "bias_hh_l0_reverse", "weight_ih_l1_reverse", "weight_hh_l1_reverse", "bias_ih_l1_reverse", "bias_hh_l1_reverse". 
#   size mismatch for weight_ih_l1: copying a param with shape torch.Size([12, 3]) from checkpoint, the shape in current model is torch.Size([12, 6]).

The SO cross-post also mentioned workarounds.

I’ve got it running. Check above code. Why is the val score 0%?

I’ve updated the main line in question:
tag_space = self.classifier(lstm_out[:,-1,:] + torch.flip(lstm_out[:,-1,:], [0, 1])) # ! # lstm_out[:,-1,:] - 1 tensor

Instead of lstm_out, use the last hidden state. In other words, instead of

lstm_out, _ = self.lstm(embeds)

do

lstm_out, hidden = self.lstm(embeds)

And use hidden as it contains the last hidden state with respect to both directions. It’s much more convenient to use. If you use lstm_out, the last hidden state of the forward direction is at index -1, and the last hidden state of the backward direction is at index 0 (w.r.t. to the correct dimension of the tensor).

Note that you still have to use view() or something on hidden to get the correct hidden state (e.g., in case you have multiple layers). You can have a look at my code here, the important snippet is – it’s a bit verbose since I support both GRU/LSTM and uni/bidirectional:

# Extract last hidden state
if self.params.rnn_type == RnnType.GRU:
    final_state = self.hidden.view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
elif self.params.rnn_type == RnnType.LSTM:
    final_state = self.hidden[0].view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
# Handle directions
final_hidden_state = None
if self.num_directions == 1:
    final_hidden_state = final_state.squeeze(0)
elif self.num_directions == 2:
    h_1, h_2 = final_state[0], final_state[1]
    # final_hidden_state = h_1 + h_2               # Add both states (requires changes to the input size of first linear layer + attention layer)
    final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states
1 Like

This is definitely brining me closer!

I just get this small error now:

<ipython-input-17-542f29e75b1a> in forward(self, sentence, labels)
     67         loss = None
     68         if labels:
---> 69             loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
     70         return loss, logits
     71 

RuntimeError: shape '[-1, 38]' is invalid for input of size 1000

Will update post @vdw