I wish to freeze the following model Recognizer_Net
which is structured as:
class LayerNorm_LSTM(nn.Module):
def __init__(self, in_dim=512, hidden_dim=512, bidirectional=False):
super(LayerNorm_LSTM, self).__init__()
self.layernorm = nn.LayerNorm(in_dim)
self.lstm = nn.LSTM(in_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)
@torch.jit.export
def forward(
self, input_: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
input_ = self.layernorm(input_)
lstm_out, hidden = self.lstm(input_, hidden)
return lstm_out, hidden
class Recognizer_Net(nn.Module):
def __init__(self, input_dim, hidden_dim, num_hidden, BN_dim, mel_mean, mel_std, bidirectional=False):
super(MT_LSTM_BN_LayerNorm_Net, self).__init__()
self.linear_pre1 = nn.Linear(input_dim, hidden_dim)
self.linear_pre2 = nn.Linear(hidden_dim, hidden_dim)
self.lstm = nn.ModuleList([
LayerNorm_LSTM(in_dim=hidden_dim, hidden_dim=hidden_dim,
bidirectional=bidirectional)
for i in range(num_hidden)
])
self.BN_linear = nn.Linear(hidden_dim, BN_dim)
self.tanh = nn.Tanh()
self.num_hidden = num_hidden
self.hidden_dim = hidden_dim
self.mel_mean = mel_mean
self.mel_std = mel_std
@torch.jit.export
def init(self) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
cell_states = torch.zeros(self.num_hidden, 1, self.hidden_dim)
return hidden_states, cell_states
@torch.jit.export
def forward(
self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
hidden_states, cell_states = hidden
x = (x - self.mel_mean) / self.mel_std
pre_linear1 = F.relu(self.linear_pre1(x))
lstm_out = F.relu(self.linear_pre2(pre_linear1))
for l, lstm_layer in enumerate(self.lstm):
lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]) = lstm_layer(lstm_out, (hidden_states[l:l+1], cell_states[l:l+1]))
BN_out = self.tanh(self.BN_linear(lstm_out))
return BN_out, (hidden_states, cell_states)
I compile the JIT model as follows:
recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256, mel_mean=mel_mean, mel_std=mel_std, bidirectional=False).cpu()
recognizer.load_state_dict(checkpoint)
recognizer.eval()
net_jit = torch.jit.script(recognizer)
net_jit = torch.jit.freeze(net_jit, preserved_attrs=["init"])
Now when I compare the non-JIT model with this newly frozen JIT model, I find that the returned BN_out
is consistent, the returned cell_states
is consistent, BUT the returned hidden_states
is not consistent. Such as:
input = torch.zeros((1, 300, 80), dtype=torch.float)
recognizer_jit = torch.jit.load(frozen_jit_model_path, map_location='cpu')
recognizer_hidden_jit = recognizer_jit.init()
BN_out_jit, recognizer_hidden_jit = recognizer_jit.forward(input, recognizer_hidden_jit)
recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256, mel_mean=mel_mean, mel_std=mel_std, bidirectional=False).cpu()
recognizer.load_state_dict(checkpoint)
recognizer.eval()
recognizer_hidden = recognizer.init()
BN_out, recognizer_hidden = recognizer.forward(input, recognizer_hidden)
I get that BN_out_jit == BN_out
and recognizer_hidden_jit[1] == recognizer_hidden[1]
, but recognizer_hidden_jit[0] != recognizer_hidden[0]
.
Note that I do not get this problem if I do not use the Freezing API when compiling the JIT model. Is this a bug with the Freezing API? Or am I missing some steps in the freezing of the jit model?