Wrapping the T5 generate function with @torch.jit.script

Hi,

Main motivation: I’m trying to wrap the generate functionality of a T5 model with TorchScript.

Issue:
I am trying to wrap the T5Model generate() function from the t5_demo notebook with @torch.jit.script. When doing so, I get the error 'Unknown type name ‘T5Model’'.

Here is the relevant code:

import torch
from torch import Tensor
import torch.nn.functional as F
from torchtext.prototype.models import T5Model, T5_BASE_GENERATION


def beam_search(
    beam_size: int,
    step: int,
    bsz: int,
    decoder_output: Tensor,
    decoder_tokens: Tensor,
    scores: Tensor,
    incomplete_sentences: Tensor,
):
    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
    top = torch.topk(probs, beam_size)

    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
    # decoder_tokens has shape (N,L) -> (N,B,L)
    # top.indices has shape (N,B) - > (N,B,1)
    # x has shape (N,B,L+1)
    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)

    # beams are first created for a given sequence
    if step == 1:
        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
        # new_scores has shape (batch_size,B)
        # incomplete_sentences has shape (batch_size * B) = (N)
        new_decoder_tokens = x.view(-1, step + 1)
        new_scores = top.values
        new_incomplete_sentences = incomplete_sentences

    # beams already exist, want to expand each beam into possible new tokens to add
    # and for all expanded beams belonging to the same sequence, choose the top k
    else:
        # scores has shape (batch_size,B) -> (N,1) -> (N,B)
        # top.values has shape (N,B)
        # new_scores has shape (N,B) -> (batch_size, B^2)
        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)

        # v, i have shapes (batch_size, B)
        v, i = torch.topk(new_scores, beam_size)

        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
        # i has shape (batch_size, B) -> (batch_size, B, L+1)
        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
        x = x.view(bsz, -1, step + 1)
        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)

        # need to update incomplete sentences in case one of the beams was kicked out
        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)

        # now can use i to extract those beams that were selected
        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N
        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)

        # new_scores has shape (batch_size, B)
        new_scores = v

    return new_decoder_tokens, new_scores, new_incomplete_sentences


@torch.jit.script
def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:

    # pass tokens through encoder
    bsz = encoder_tokens.size(0)
    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
    encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
    encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]

    encoder_output = model.norm1(encoder_output)
    encoder_output = model.dropout2(encoder_output)

    # initialize decoder input sequence; T5 uses the padding index as the start token for the decoder sequence
    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
    scores = torch.zeros((bsz, beam_size))

    # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

    # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
    for step in range(model.config.max_seq_len):

        if step == 1:
            # duplicate and order encoder output so that each beam is treated as its own independent sequence
            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
            new_order = new_order.to(encoder_tokens.device).long()
            encoder_output = encoder_output.index_select(0, new_order)
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

        # causal mask and padding mask for decoder sequence
        tgt_len = decoder_tokens.shape[1]
        decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
        decoder_padding_mask = decoder_tokens.eq(model.padding_idx)

        # the T5 implementation uses the padding idx to start the sequence, so don't mask the first position
        decoder_padding_mask[:, 0] = False

        # pass decoder sequence through decoder
        decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens))
        decoder_output = model.decoder(
            decoder_embeddings,
            memory=encoder_output,
            tgt_mask=decoder_mask,
            tgt_key_padding_mask=decoder_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
        )[0]

        decoder_output = model.norm2(decoder_output)
        decoder_output = model.dropout4(decoder_output)
        decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
        decoder_output = model.lm_head(decoder_output)

        decoder_tokens, scores, incomplete_sentences = beam_search(
            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
        )
        # ignore newest tokens for sentences that are already complete
        decoder_tokens[:, -1] *= incomplete_sentences

        # update incomplete_sentences to remove those that were just ended
        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

        # early stop if all sentences have been ended
        if (incomplete_sentences == 0).all():
            break

    # take most likely sequence
    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
    return decoder_tokens

And the error output is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [3], line 67
     61         new_scores = v
     63     return new_decoder_tokens, new_scores, new_incomplete_sentences
     66 @torch.jit.script
---> 67 def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:
     68 
     69     # pass tokens through encoder
     70     bsz = encoder_tokens.size(0)
     71     encoder_padding_mask = encoder_tokens.eq(model.padding_idx)

File ~/venv/miagi/lib/python3.9/site-packages/torch/jit/_script.py:1343, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1341 if _rcb is None:
   1342     _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
-> 1343 fn = torch._C._jit_script_compile(
   1344     qualified_name, ast, _rcb, get_default_args(obj)
   1345 )
   1346 # Forward docstrings
   1347 fn.__doc__ = obj.__doc__

RuntimeError: 
Unknown type name 'T5Model':
  File "/var/folders/1l/2ssvf7hd1mj2r_pp0hjg_fqh0000gn/T/ipykernel_64584/592630873.py", line 67
@torch.jit.script
def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:
                                                          ~~~~~~~ <--- HERE

    # pass tokens through encoder

Any idea why this is happening?
Is there a code example showing how to wrap the full package (tokenizer, model and generate functionality) as TorchScript?

For my first question: it looks like model classes such as T5Model can't be used as parameter type annotations in a TorchScript function.
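
This isn't specific to T5Model: a function compiled with @torch.jit.script cannot take an nn.Module subclass as an annotated parameter, so the compiler fails to resolve the type name. Here is a toy illustration of the same failure mode (the class name ToyModule is made up, not part of torchtext):

import torch
import torch.nn as nn
from torch import Tensor


class ToyModule(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x * 2


# expected to fail the same way: RuntimeError: Unknown type name 'ToyModule'
@torch.jit.script
def run(model: ToyModule, x: Tensor) -> Tensor:
    return model(x)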

But the motivation remains: is there a way of wrapping the full package (tokenizer, model and generate functionality) as TorchScript?

I have managed to wrap the tokenizer, model and generate functionality as TorchScript. I will post a full example here soon.
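
The core idea is to make the T5 model a submodule of a regular nn.Module and move the generation loop into its forward(), so that torch.jit.script never has to resolve T5Model as a parameter type. Below is a minimal sketch of that pattern, using greedy decoding instead of the full beam search to keep it short; the wrapper name T5Generator is just illustrative, and it assumes the prototype T5Model scripts cleanly as a submodule:

import torch
import torch.nn as nn
from torch import Tensor


class T5Generator(nn.Module):
    def __init__(self, model, eos_idx: int):
        super().__init__()
        # the T5Model is held as a submodule instead of being passed as a function argument
        self.model = model
        self.eos_idx = eos_idx

    def forward(self, encoder_tokens: Tensor) -> Tensor:
        # pass tokens through encoder
        bsz = encoder_tokens.size(0)
        encoder_padding_mask = encoder_tokens.eq(self.model.padding_idx)
        encoder_embeddings = self.model.dropout1(self.model.token_embeddings(encoder_tokens))
        encoder_output = self.model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]
        encoder_output = self.model.dropout2(self.model.norm1(encoder_output))

        # T5 uses the padding index as the start token for the decoder sequence
        decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * self.model.padding_idx

        for step in range(self.model.config.max_seq_len):
            # causal mask and padding mask for the decoder sequence
            tgt_len = decoder_tokens.shape[1]
            decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
            decoder_padding_mask = decoder_tokens.eq(self.model.padding_idx)
            decoder_padding_mask[:, 0] = False

            decoder_embeddings = self.model.dropout3(self.model.token_embeddings(decoder_tokens))
            decoder_output = self.model.decoder(
                decoder_embeddings,
                memory=encoder_output,
                tgt_mask=decoder_mask,
                tgt_key_padding_mask=decoder_padding_mask,
                memory_key_padding_mask=encoder_padding_mask,
            )[0]
            decoder_output = self.model.dropout4(self.model.norm2(decoder_output))
            decoder_output = self.model.lm_head(decoder_output * (self.model.config.embedding_dim ** -0.5))

            # greedy step: pick the most likely next token for every sequence
            next_token = decoder_output[:, -1].argmax(dim=-1, keepdim=True)
            decoder_tokens = torch.cat([decoder_tokens, next_token], dim=-1)

            # early stop once every sequence has produced the end-of-sequence token
            if bool((next_token == self.eos_idx).all()):
                break

        return decoder_tokens


# usage sketch: model is a torchtext prototype T5Model instance, eos_idx its end-of-sequence id
# scripted_generator = torch.jit.script(T5Generator(model, eos_idx=1))
# tokens = scripted_generator(encoder_tokens)

The same trick applies to the tokenizer: torchtext's text transforms are themselves nn.Modules, so they can typically be held as another submodule (or scripted separately) and composed with the generator.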

A code suggestion for this can be found in the notebook added in this pull request: