We plan to deploy our ASR model in TorchScript format. During deployment, however, we found that inference with the TorchScript model is slower than with the default (eager) PyTorch model — both run from Python. Why could this happen? My understanding was that inference with TorchScript should be at least as fast as, if not faster than, eager PyTorch.
Encode time (seconds):
                 previous                   current
batch1 encode    0.28  0.25  0.27  0.30     0.65  0.68  0.64
batch2 encode    0.25  0.28  0.27  0.24     1.04  1.08  1.13
def len_to_mask(lens):
    """Turn a tensor of sequence lengths into a boolean validity mask.

    For 1-D ``lens`` of shape (B,), the result has shape (B, max(lens)).
    Entry ``[b, t]`` is True exactly when ``t < lens[b]``, i.e. when frame
    ``t`` of sequence ``b`` is a real (non-padding) frame. The mask is
    created on the same device as ``lens``.
    """
    longest = lens.max()
    positions = torch.arange(longest, device=lens.device)
    # (1, T) positions compared against (B, 1) lengths broadcasts to (B, T).
    return positions.unsqueeze(0) < lens.unsqueeze(1)
class Encoder(nn.Module):
    """Conformer-style encoder.

    The pipeline is:
    - an optional two-layer Conv2d+Swish subsampling front-end;
    - a linear input projection with dropout;
    - a stack of ``EncoderLayer`` blocks;
    - a final ``LayerNorm``.

    Args:
        d_input: feature dimension of each input frame.
        d_model: model (hidden) dimension.
        d_inner: inner dimension forwarded to every ``EncoderLayer``.
        n_layer: number of ``EncoderLayer`` blocks.
        n_head: number of attention heads forwarded to every ``EncoderLayer``.
        n_kernel: kernel size forwarded to every ``EncoderLayer``.
        attn_mode: self-attention masking mode; see ``get_mask``.
        rel_pos: must be falsy, because relative positions are not supported here.
        dropout: dropout applied after the input projection. It is also
            forwarded to each layer.
        layer_drop: maximum layer-drop level. Layer ``i`` (0-based) receives
            ``(i + 1) * layer_drop / n_layer``.
        time_ds: accepted for config compatibility; unused in this class.
        use_cnn: if True, prepend two Conv2d+Swish layers over (time, freq).
        time_kn, time_std: time kernel / stride of the second conv.
        freq_kn, freq_std: frequency kernel / stride of both convs.
        chunk_size: chunk size handed to ``get_mask`` on every forward pass.
    """

    def __init__(self, d_input, d_model, d_inner, n_layer, n_head, n_kernel=25,
                 attn_mode='normal', rel_pos=False, dropout=0.1, layer_drop=0.,
                 time_ds=1, use_cnn=False, time_kn=3, time_std=2, freq_kn=3, freq_std=2, chunk_size=0):
        super().__init__()
        self.time_kn, self.time_std = time_kn, time_std
        if use_cnn:
            # The first conv's time kernel/stride is fixed at (3, 2).
            # forward() hard-codes the same values when recomputing lengths,
            # so keep the two in sync.
            cnn = [nn.Conv2d(1, 32, kernel_size=(3, freq_kn), stride=(2, freq_std)), Swish(),
                   nn.Conv2d(32, 32, kernel_size=(time_kn, freq_kn), stride=(time_std, freq_std)), Swish()]
            self.cnn = nn.Sequential(*cnn)
            # Frequency size left after both convs, times the 32 output channels
            # (these get flattened into the feature axis in forward()).
            d_input = ((((d_input - freq_kn) // freq_std + 1) - freq_kn) // freq_std + 1) * 32
        else:
            self.cnn = None
        self.emb = XavierLinear(d_input, d_model)
        self.drop = nn.Dropout(dropout)
        self.rel_pos = rel_pos
        assert not self.rel_pos
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, n_kernel, dropout, rel_pos)
            for _ in range(n_layer)])
        self.norm = nn.LayerNorm(d_model)
        self.layer_drop = layer_drop
        self.attn_mode = attn_mode
        self.chunk_size = chunk_size

    def get_mask(self, mode: str, src_mask, chunk_size: int, static_chunk_size: int = 16):
        """Build the (B, T, T) self-attention mask for ``mode``.

        Args:
            mode: the masking mode.
                - ``'normal'``: only padding is masked.
                - ``'unidrection'`` (``'unidirection'`` is also accepted):
                  padding and future positions are masked.
                - ``'static_chunk'`` and ``'dynamic_chunk'``: combine padding
                  with ``subsequent_chunk_mask``.
                Any other value fails the final assert.
            src_mask: (B, T) mask; entries equal to 0 are treated as padding.
            chunk_size: used by ``'dynamic_chunk'`` only.
                - ``< 0``: full context.
                - ``> 0``: use the value as given.
                - ``0``: draw a random chunk size.
            static_chunk_size: chunk size used by ``'static_chunk'``.

        Returns:
            A (B, T, T) bool tensor. In ``'normal'`` and ``'unidrection'``
            modes, True marks a key position that must not be attended.
        """
        lt = src_mask.size(1)
        # True marks a padded key position.
        slf_mask = src_mask.eq(0)
        slf_mask = slf_mask.unsqueeze(1).expand(-1, lt, -1)  # b x lq x lk
        if mode == 'unidrection' or mode == 'unidirection':
            # Strict upper triangle: query t may not see keys t+1 .. T-1.
            tri_mask = torch.ones((lt, lt), device=src_mask.device, dtype=torch.uint8)
            tri_mask = torch.triu(tri_mask, diagonal=1)
            tri_mask = tri_mask.unsqueeze(0).expand(src_mask.size(0), -1, -1)
            slf_mask = (slf_mask + tri_mask).gt(0)
        elif mode == 'static_chunk':
            # NOTE(review): slf_mask uses True = masked out, but it is AND-ed
            # with chunk_masks here. If subsequent_chunk_mask follows the WeNet
            # convention (True = may attend), the correct combination is
            # `slf_mask | ~chunk_masks`. Confirm before relying on chunk modes.
            chunk_masks = subsequent_chunk_mask(lt, static_chunk_size, -1,
                                                slf_mask.device)  # (L, L)
            chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
            slf_mask = slf_mask & chunk_masks  # (B, L, L)
        elif mode == 'dynamic_chunk':
            if chunk_size < 0:
                chunk_size = lt
            elif chunk_size == 0:
                if lt <= 1:
                    # torch.randint(1, lt) requires lt > 1. With at most one
                    # frame, full context is the only sensible choice.
                    chunk_size = lt
                else:
                    # chunk size is either [1, 25] or full context (max_len).
                    # Since we use 4 times subsampling and allow up to 1s (100 frames)
                    # delay, the maximum frame is 100 / 4 = 25.
                    # int(...) keeps chunk_size an int under TorchScript.
                    chunk_size = int(torch.randint(1, lt, (1, )).item())
                    if chunk_size > lt // 2:
                        chunk_size = lt
                    else:
                        chunk_size = chunk_size % 25 + 1
            # chunk_size > 0 is used as given.
            # NOTE(review): same mask-convention question as 'static_chunk'.
            chunk_masks = subsequent_chunk_mask(lt, chunk_size,
                                                -1,
                                                slf_mask.device)  # (L, L)
            chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
            slf_mask = slf_mask & chunk_masks  # (B, L, L)
        else:
            assert mode == 'normal'
        return slf_mask

    def forward(self, src_seq, src_mask):
        """Encode a padded batch of feature frames.

        Args:
            src_seq: input features. With the CNN front-end this is treated as
                (B, T, F): a channel dim is added and the convs run over (T, F).
            src_mask: (B, T) mask that is nonzero on valid frames.

        Returns:
            A tuple ``(enc_out, src_mask)``:
            - ``enc_out``: the layer-normalised encoder output, multiplied by
              the mask so padded frames are zero.
            - ``src_mask``: the mask, recomputed for the subsampled length when
              the CNN front-end is used.
        """
        if self.cnn is not None:
            src_seq = src_seq.unsqueeze(1)  # (B, 1, T, F)
            src_seq = self.cnn(src_seq)  # (B, 32, T', F')
            src_seq = src_seq.permute(0, 2, 1, 3).contiguous()  # (B, T', 32, F')
            src_seq = src_seq.view(src_seq.size(0), src_seq.size(1), -1)
            if src_mask is not None:
                # Push valid lengths through both time convolutions:
                # - first conv: time kernel 3, stride 2 (hard-coded in __init__);
                # - second conv: time_kn, time_std.
                lens = src_mask.sum(-1)
                lens = (((lens - 3) // 2 + 1) - self.time_kn) // self.time_std + 1
                src_mask = len_to_mask(lens).to(src_mask.dtype)
        enc_out = src_seq if self.emb is None else self.drop(self.emb(src_seq))
        # -- Prepare masks
        slf_mask = self.get_mask(self.attn_mode, src_mask, self.chunk_size)
        nl = len(self.layer_stack)
        # Built once and reused by every layer and by the final masking,
        # instead of calling src_mask.unsqueeze(-1) on each iteration.
        mask = src_mask.unsqueeze(-1)  # (B, T, 1), broadcasts over features
        for idx, enc_layer in enumerate(self.layer_stack):
            drop_level = (idx + 1.) * self.layer_drop / nl
            enc_out = enc_layer(enc_out, mask, slf_mask, drop_level)
        enc_out = self.norm(enc_out) * mask
        return enc_out, src_mask
class RNNTransducer(nn.Module):
    """
    RNNT transducer model.

    In the visible snippet only the encoder is wired up: ``self.encoder`` is an
    ``Encoder`` built from ``encoder_configs``. The remaining set-up is elided
    (``# ...``). NOTE(review): decoder / joint-network / loss handling for
    ``dec_name``, ``decoder_configs``, ``jointnet_name``, ``jointnet_configs``,
    ``lossname`` and ``pretrained_encoder_path`` is not shown here; confirm
    against the full definition.
    """
    def __init__(self, enc_name : str,
                 encoder_configs : dict,
                 dec_name: str,
                 decoder_configs: dict,
                 jointnet_name: str,
                 jointnet_configs: dict,
                 lossname: str = "warprnnt",
                 pretrained_encoder_path = None
                 ):
        super().__init__()
        # Only a conformer encoder is accepted (case-insensitive).
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError for config validation.
        assert enc_name.lower() =="conformer"
        # encoder_configs is splatted into Encoder(...), so its keys must match
        # Encoder.__init__ parameter names.
        self.encoder = Encoder(**encoder_configs)
        # ...
    @torch.jit.export
    def encode(self, inputs: Tensor, inputs_mask: Tensor):
        """Run only the encoder on a padded batch.

        Marked ``@torch.jit.export`` so TorchScript compiles it and it can be
        called on the scripted module even though it is not ``forward``.

        Returns:
            ``(encoder_outputs, outputs_mask)`` exactly as produced by
            ``self.encoder``. The mask may be shorter than ``inputs_mask``
            when the encoder's CNN front-end subsamples time.
        """
        encoder_outputs, outputs_mask= self.encoder(inputs, inputs_mask)
        return encoder_outputs, outputs_mask