JIT trace failure (first diverging operator / Node diff)

I have a problem where torch.jit.trace fails during execution. The error reports the first diverging operator:

```
First diverging operator:
Node diff:
- %transformer : __torch__.torch.nn.modules.transformer.Transformer = prim::GetAttr[name="transformer"]
+ %transformer : __torch__.torch.nn.modules.transformer.___torch_mangle_297.Transformer = prim::GetAttr[name="transformer"]
?                                                       ++++++++++++++++++++
```
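
The diff above appears to come from the tracer's built-in consistency check, which re-runs the module and compares the recorded graphs; the second run resolves `self.transformer` to a mangled class name (`___torch_mangle_297.Transformer`). For reference, this is only a sketch of how the checker can be relaxed while debugging, assuming the standard `check_trace` keyword of `torch.jit.trace` (the helper name is illustrative, not part of my script):

```python
# Sketch only: trace with the consistency check disabled so the recorded graph
# can be inspected even when the re-trace resolves attributes to mangled types.
# Assumption: check_trace is the standard torch.jit.trace keyword argument.
import torch

def trace_for_debugging(model, example_inputs):
    traced = torch.jit.trace(model, example_inputs, check_trace=False)
    print(traced.graph)  # dump the recorded graph for comparison
    return traced
```

The full script that produces the failure is below.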

```python
# encoding: utf-8

import math

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from pathlib import Path
from torch.utils.mobile_optimizer import optimize_for_mobile

device = torch.device('cpu')

max_length = 72

print(torch.__version__)

class PositionalEncoding(nn.Module):
    # Standard sinusoidal positional encoding, precomputed once and registered as a buffer.

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model).to(device)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the positional encoding for the first x.size(1) positions.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

class TranslationModel(nn.Module):

    def __init__(self, d_model, src_vocab, tgt_vocab, dropout=0.1):
        super(TranslationModel, self).__init__()

        self.src_embedding = nn.Embedding(len(src_vocab), d_model, padding_idx=2)
        self.tgt_embedding = nn.Embedding(len(tgt_vocab), d_model, padding_idx=2)
        self.positional_encoding = PositionalEncoding(d_model, dropout, max_len=max_length)
        self.transformer = nn.Transformer(d_model, dropout=dropout, batch_first=True)
        self.predictor = nn.Linear(d_model, len(tgt_vocab))

    def forward(self, src, tgt):
        # Causal mask for the decoder and key-padding masks for both sequences.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1])
        src_key_padding_mask = TranslationModel.get_key_padding_mask(src)
        tgt_key_padding_mask = TranslationModel.get_key_padding_mask(tgt)
        src = self.src_embedding(src)
        tgt = self.tgt_embedding(tgt)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask)

        return out

    @staticmethod
    def get_key_padding_mask(tokens):
        # Token id 2 is the padding id.
        return tokens == 2

# Load the base tokenizer model; this uses the base BERT model. "uncased" means
# it does not distinguish between upper and lower case.

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

def en_tokenizer(line):
    return tokenizer.encode(line, add_special_tokens=False).tokens

work_dir = Path("./dataset")
en_vocab = None
zh_vocab = None
en_vocab_file = work_dir / "vocab_en.pt"
zh_vocab_file = work_dir / "vocab_zh.pt"
en_vocab = torch.load(en_vocab_file, map_location="cpu")
zh_vocab = torch.load(zh_vocab_file, map_location="cpu")

model_dir = Path("./drive/MyDrive/model/transformer_checkpoints")
model_checkpoint = "model_475000.pt"

model = torch.load(model_dir / model_checkpoint, map_location=torch.device('cpu'))
model = model.to(device)

model = model.eval()

# Example inputs for tracing: a short tokenized source sentence and a single-token target.
src = torch.tensor([0, 163, 1]).unsqueeze(0).to(device)
tgt = torch.tensor([[0]]).to(device)

# Quantize the model, then trace it with the example inputs, optimize the traced
# module, and export it for the PyTorch lite interpreter.
quantized_model = torch.quantization.convert(model)

scripted_module = torch.jit.trace(quantized_model, (tgt, src))
opt_module = optimize_for_mobile(scripted_module)
opt_module._save_for_lite_interpreter("model_quantized.ptl")
```
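
For completeness, here is a sketch (not part of the failing run) of how the exported file could be loaded back in Python to confirm the artifact runs at all, assuming the `model_quantized.ptl` written above; `_load_for_lite_interpreter` is the Python-side loader for lite-interpreter models:

```python
# Sketch: load the exported lite-interpreter model back in Python and run the
# same example inputs that were used for tracing.
import torch
from torch.jit.mobile import _load_for_lite_interpreter

lite_model = _load_for_lite_interpreter("model_quantized.ptl")
src = torch.tensor([0, 163, 1]).unsqueeze(0)
tgt = torch.tensor([[0]])
out = lite_model(tgt, src)  # same argument order as in the trace call above
print(out.shape)
```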