TorchScript inference slower than default torch model

We plan to deploy our ASR model in TorchScript format. During deployment, we found that inference with the TorchScript model is slower than with the default (eager) PyTorch model, both run from Python. Why could this happen? In my understanding, inference with TorchScript should be as fast as or faster than eager mode.

seconds          previous (default PyTorch)    current (TorchScript)
batch1 encode    0.28 0.25 0.27 0.30           0.65 0.68 0.64
batch2 encode    0.25 0.28 0.27 0.24           1.04 1.08 1.13

import torch
from torch import Tensor, nn

# Swish, XavierLinear, EncoderLayer and subsequent_chunk_mask come from
# elsewhere in this codebase and are not shown here.

def len_to_mask(lens):
    # boolean padding mask of shape (batch, max_len) built from lengths
    max_len = lens.max()
    return torch.arange(max_len, device=lens.device)[None, :] < lens[:, None]

class Encoder(nn.Module):
    def __init__(self, d_input, d_model, d_inner, n_layer, n_head, n_kernel=25,
            attn_mode='normal', rel_pos=False, dropout=0.1, layer_drop=0.,
            time_ds=1, use_cnn=False, time_kn=3, time_std=2, freq_kn=3, freq_std=2, chunk_size=0):
        super().__init__()

        self.time_kn, self.time_std = time_kn, time_std
        if use_cnn:
            cnn = [nn.Conv2d(1, 32, kernel_size=(3,freq_kn), stride=(2,freq_std)), Swish(),
                   nn.Conv2d(32, 32, kernel_size=(time_kn,freq_kn), stride=(time_std,freq_std)), Swish()]
            self.cnn = nn.Sequential(*cnn)
            d_input = ((((d_input - freq_kn) // freq_std + 1) - freq_kn) // freq_std + 1)*32
        else:
            self.cnn = None

        self.emb = XavierLinear(d_input, d_model)
        self.drop = nn.Dropout(dropout)
        self.rel_pos = rel_pos
        assert self.rel_pos == False
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, n_kernel, dropout, rel_pos)
            for _ in range(n_layer)])

        self.norm = nn.LayerNorm(d_model)
        self.layer_drop = layer_drop
        self.attn_mode = attn_mode
        self.chunk_size = chunk_size

    def get_mask(self, mode:str, src_mask, chunk_size:int, static_chunk_size:int=16):
        lt = src_mask.size(1)
        slf_mask = src_mask.eq(0)
        slf_mask = slf_mask.unsqueeze(1).expand(-1, lt, -1) # b x lq x lk

        if mode == 'unidrection':
            tri_mask = torch.ones((lt, lt), device=src_mask.device, dtype=torch.uint8)
            tri_mask = torch.triu(tri_mask, diagonal=1)
            tri_mask = tri_mask.unsqueeze(0).expand(src_mask.size(0), -1, -1)
            slf_mask = (slf_mask + tri_mask).gt(0)

        elif mode == 'static_chunk':
            chunk_masks = subsequent_chunk_mask(lt, static_chunk_size,-1,
                                                slf_mask.device)  # (L, L)
            chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
            slf_mask = slf_mask & chunk_masks  # (B, L, L)

        elif mode == 'dynamic_chunk':
            if chunk_size < 0:
                # full-context attention
                chunk_size = lt
            elif chunk_size == 0:
                # chunk size is either in [1, 25] or full context (max_len).
                # Since we use 4x subsampling and allow up to 1 s (100 frames)
                # of delay, the maximum chunk is 100 / 4 = 25 frames.
                chunk_size = torch.randint(1, lt, (1,)).item()
                if chunk_size > lt // 2:
                    chunk_size = lt
                else:
                    chunk_size = chunk_size % 25 + 1
            # a positive chunk_size is used as-is

            chunk_masks = subsequent_chunk_mask(lt, chunk_size, -1,
                                                slf_mask.device)  # (L, L)
            chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
            slf_mask = slf_mask & chunk_masks  # (B, L, L)

        else:
            assert mode == 'normal'
        return slf_mask

    def forward(self, src_seq, src_mask):
        # -- Forward
        if self.cnn is not None:
            src_seq = src_seq.unsqueeze(1)
            src_seq = self.cnn(src_seq)
            src_seq = src_seq.permute(0, 2, 1, 3).contiguous()
            src_seq = src_seq.view(src_seq.size(0), src_seq.size(1), -1)
            if src_mask is not None:
                lens = src_mask.sum(-1)
                lens = (((lens - 3) // 2 + 1) - self.time_kn) // self.time_std + 1
                src_mask = len_to_mask(lens).to(src_mask.dtype)

        enc_out = src_seq if self.emb is None else self.drop(self.emb(src_seq))

        # -- Prepare masks
        slf_mask = self.get_mask(self.attn_mode, src_mask, self.chunk_size)

        nl, mask = len(self.layer_stack), src_mask.unsqueeze(-1)
        for l, enc_layer in enumerate(self.layer_stack):
            drop_level = (l + 1.) * self.layer_drop / nl
            enc_out = enc_layer(enc_out, mask, slf_mask, drop_level)
        enc_out = self.norm(enc_out) * mask

        return enc_out, src_mask

class RNNTransducer(nn.Module):
    """
    RNNT transducer model
    """
    def __init__(self, enc_name: str,
                 encoder_configs: dict,
                 dec_name: str,
                 decoder_configs: dict,
                 jointnet_name: str,
                 jointnet_configs: dict,
                 lossname: str = "warprnnt",
                 pretrained_encoder_path=None):
        super().__init__()
        assert enc_name.lower() == "conformer"
        self.encoder = Encoder(**encoder_configs)
        # ...

    @torch.jit.export
    def encode(self, inputs: Tensor, inputs_mask: Tensor):
        encoder_outputs, outputs_mask = self.encoder(inputs, inputs_mask)
        return encoder_outputs, outputs_mask
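
For reference, the export and the call to the exported encode() look roughly like this. This is a minimal sketch: the config dict, file name, and input shapes below are placeholders, not the actual deployment values.

import torch

model = RNNTransducer(**model_configs)        # model_configs: placeholder config dict
model.eval()

scripted = torch.jit.script(model)            # compile to TorchScript
scripted.save("rnnt_scripted.pt")

loaded = torch.jit.load("rnnt_scripted.pt")
feats = torch.randn(1, 200, 80)               # (batch, time, feature) dummy features
feats_mask = torch.ones(1, 200, dtype=torch.long)  # dummy frame mask
with torch.no_grad():
    enc_out, enc_mask = loaded.encode(feats, feats_mask)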

Could you share how the end-to-end timing is being done to arrive at these measurements? I'm not able to understand the meaning of the "seconds previous current" results.

Looks like something went wrong with the text formatting.
seconds is the time the model takes to finish inference, previous means the default PyTorch (eager) model, and current means the TorchScript model. The inference is run from Python.
I don't know why the inference is so much slower with TorchScript.

seconds (default PyTorch)
batch1 encode    0.28 0.25 0.27 0.30
batch2 encode    0.25 0.28 0.27 0.24

seconds (TorchScript)
batch1 encode    0.65 0.68 0.64
batch2 encode    1.04 1.08 1.13

Hi, maybe I should ping @eqy

Could you share how the timing is being done (e.g., with a profiler, or manually with torch.cuda.synchronize)?

@eqy The timing is done using time.time() because the inference is carried out on CPU.
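
Roughly like this, as a minimal sketch (the model handle and the input tensors are placeholders; everything runs on the CPU, so there is no CUDA synchronization involved):

import time
import torch

def time_encode(model, feats, feats_mask, n_runs: int = 4):
    # Time n_runs calls to the exported encode() with a simple wall-clock timer.
    times = []
    with torch.no_grad():
        for _ in range(n_runs):
            t0 = time.time()
            model.encode(feats, feats_mask)
            times.append(time.time() - t0)
    return times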

At this point I would check if it would be possible to narrow down the slowdown to some specific part of the model (e.g., using a profiler or comparing individual layers).
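
For example, something along these lines with torch.profiler could show which ops account for the gap (the model handle and inputs here are just placeholders):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        model.encode(feats, feats_mask)

# Running this once for the eager model and once for the scripted model makes
# it easier to see where the extra time goes.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))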

Thank you! For now, I find this is mostly due to model warm-up: performance becomes steady and fast after a few warm-up runs.
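
For anyone who hits the same issue, a minimal sketch of the warm-up I mean (the scripted model and input tensors are placeholders):

import time
import torch

with torch.no_grad():
    for _ in range(5):                         # untimed warm-up iterations
        scripted.encode(feats, feats_mask)

    t0 = time.time()                           # timed, post-warm-up call
    scripted.encode(feats, feats_mask)
    print(f"steady-state encode time: {time.time() - t0:.3f}s")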