This is the architecture of a model that I have trained:
class LayerNorm_LSTM(nn.Module):
    """LayerNorm over the last (feature) dimension followed by a single-layer ``nn.LSTM``.

    The LSTM is built with ``batch_first=True``, so batched input is expected
    as ``(batch, seq, in_dim)``. The output's last dimension is ``hidden_dim``,
    or ``2 * hidden_dim`` when ``bidirectional`` is set.
    """

    def __init__(self, in_dim, hidden_dim, bidirectional=False):
        super(LayerNorm_LSTM, self).__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.lstm = nn.LSTM(in_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)

    def forward(self, input_, hidden=None):
        """Normalize ``input_`` and run it through the LSTM.

        Args:
            input_: tensor whose last dimension is ``in_dim``.
            hidden: ``(h_0, c_0)`` initial state, or ``None`` to let ``nn.LSTM``
                start from zeros (previously this argument was mandatory).

        Returns:
            ``(lstm_out, (h_n, c_n))`` exactly as returned by ``nn.LSTM``.
        """
        normed = self.layernorm(input_)
        lstm_out, hidden = self.lstm(normed, hidden)
        return lstm_out, hidden
class Recognizer_Net(nn.Module):
    """Normalize -> 2x (Linear + ReLU) -> stacked LayerNorm-LSTMs -> Linear + tanh bottleneck.

    Args:
        input_dim: size of the last (feature) dimension of the input ``x``.
        hidden_dim: width of the two pre-net layers and of each LSTM direction.
        num_hidden: number of stacked ``LayerNorm_LSTM`` layers.
        BN_dim: size of the bottleneck output.
        mel_mean, mel_std: used to normalize the input as
            ``(x - mel_mean) / mel_std``; must broadcast against ``x``.
            They are stored as plain attributes (not buffers), so they are not
            in ``state_dict`` and are not moved by ``.to(device)``.
        bidirectional: run every LSTM layer in both directions.

    Note: the unidirectional model is causal, so the output for a prefix of a
    sequence equals the prefix of the full output mathematically. In float32
    the two can still differ by rounding (~1e-7), so compare them with
    ``torch.allclose``, not ``==``.
    """

    def __init__(self, input_dim, hidden_dim, num_hidden, BN_dim, mel_mean, mel_std, bidirectional=False):
        super(Recognizer_Net, self).__init__()
        num_directions = 2 if bidirectional else 1
        lstm_out_dim = hidden_dim * num_directions
        self.linear_pre1 = nn.Linear(input_dim, hidden_dim)
        self.linear_pre2 = nn.Linear(hidden_dim, hidden_dim)
        # Layer 0 consumes the pre-net output (hidden_dim). Deeper layers
        # consume the layer below, which is 2 * hidden_dim wide when
        # bidirectional. Unidirectional shapes are unchanged.
        self.lstm = nn.ModuleList([
            LayerNorm_LSTM(in_dim=hidden_dim if i == 0 else lstm_out_dim,
                           hidden_dim=hidden_dim,
                           bidirectional=bidirectional)
            for i in range(num_hidden)
        ])
        self.BN_linear = nn.Linear(lstm_out_dim, BN_dim)
        self.tanh = nn.Tanh()
        self.num_hidden = num_hidden
        self.hidden_dim = hidden_dim
        self.num_directions = num_directions
        self.mel_mean = mel_mean
        self.mel_std = mel_std

    def init(self, batch_size=1):
        """Return zeroed ``(hidden_states, cell_states)`` for ``forward``.

        Each tensor has shape ``(num_hidden * num_directions, batch_size, hidden_dim)``
        and uses the same device/dtype as the model's parameters. Layer ``l``
        uses rows ``[l * num_directions, (l + 1) * num_directions)``.
        """
        ref = self.BN_linear.weight
        shape = (self.num_hidden * self.num_directions, batch_size, self.hidden_dim)
        hidden_states = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
        cell_states = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
        return hidden_states, cell_states

    def forward(self, x, hidden):
        """Run the network on ``x`` starting from ``hidden``.

        Args:
            x: input features, last dimension ``input_dim``; presumably
                ``(batch, time, input_dim)`` since the LSTMs are batch-first.
            hidden: ``(hidden_states, cell_states)`` laid out as returned by
                ``init``. These tensors are read but NOT modified.

        Returns:
            ``(BN_out, (hidden_states, cell_states))``, where the returned
            states are new tensors holding each layer's final state.
        """
        hidden_states, cell_states = hidden
        d = self.num_directions
        x = (x - self.mel_mean) / self.mel_std
        pre_linear1 = F.relu(self.linear_pre1(x))
        lstm_out = F.relu(self.linear_pre2(pre_linear1))
        new_h, new_c = [], []
        for l, lstm_layer in enumerate(self.lstm):
            rows = slice(l * d, (l + 1) * d)
            # Collect new states rather than writing them back into slices of
            # the caller's tensors. The original in-place write mutated the
            # tensors passed in.
            lstm_out, (h, c) = lstm_layer(lstm_out, (hidden_states[rows], cell_states[rows]))
            new_h.append(h)
            new_c.append(c)
        if new_h:  # num_hidden == 0: there are no layers, so pass the states through
            hidden_states, cell_states = torch.cat(new_h), torch.cat(new_c)
        BN_out = self.tanh(self.BN_linear(lstm_out))
        return BN_out, (hidden_states, cell_states)
And I have the following code to test inference:
# Sanity check: running on a prefix of the input should reproduce the
# corresponding prefix of the full-length output (mel_mean / mel_std are
# assumed to be defined earlier).
full_input = torch.ones((1, 300, 80), dtype=torch.float)
recognizer = Recognizer_Net(input_dim=80, hidden_dim=512, num_hidden=3, BN_dim=256,
                            mel_mean=mel_mean, mel_std=mel_std, bidirectional=False)
recognizer.eval()

with torch.no_grad():  # inference only: no autograd graph needed
    # Call the module itself rather than .forward() so registered hooks run.
    recognizer_hidden = recognizer.init()
    BN_out, recognizer_hidden = recognizer(full_input, recognizer_hidden)
    recognizer_hidden = recognizer.init()
    BN_out_test, recognizer_hidden = recognizer(full_input[:, 0:16, :], recognizer_hidden)
    recognizer_hidden = recognizer.init()
    BN_out_test2, recognizer_hidden = recognizer(full_input[:, 0:15, :], recognizer_hidden)

# Prefixes agree mathematically, but float32 results can differ in the last
# few bits. The underlying kernels may use a different blocking / summation
# order for different sequence lengths, and float addition is not
# associative. Compare with a tolerance instead of exact equality.
print(torch.allclose(BN_out[:, 0:16, :], BN_out_test, atol=1e-6))
print(torch.allclose(BN_out[:, 0:15, :], BN_out_test2, atol=1e-6))
I find that `BN_out[:, 0:16, :] == BN_out_test`, but `BN_out[:, 0:15, :] != BN_out_test2`. Here are the corresponding printed results:
(Pdb) p BN_out[:,:16,:]
tensor([[[ 0.12310354, -0.06452109, 0.04474466, ..., 0.06025446,
-0.02436436, 0.03219125],
[ 0.16362280, -0.08077041, 0.07871316, ..., 0.09137958,
-0.04021315, 0.03032400],
[ 0.18024816, -0.08485510, 0.09805698, ..., 0.10307459,
-0.04448149, 0.02867011],
...,
[ 0.18468492, -0.06737500, 0.14588971, ..., 0.09201521,
-0.00267485, 0.00157470],
[ 0.18401137, -0.06671639, 0.14601991, ..., 0.09149401,
-0.00209353, -0.00024496],
[ 0.18337716, -0.06622121, 0.14607997, ..., 0.09106418,
-0.00173651, -0.00173810]]])
(Pdb) p BN_out_test
tensor([[[ 0.12310354, -0.06452109, 0.04474466, ..., 0.06025446,
-0.02436436, 0.03219125],
[ 0.16362280, -0.08077041, 0.07871316, ..., 0.09137958,
-0.04021315, 0.03032400],
[ 0.18024816, -0.08485510, 0.09805698, ..., 0.10307459,
-0.04448149, 0.02867011],
...,
[ 0.18468492, -0.06737500, 0.14588971, ..., 0.09201521,
-0.00267485, 0.00157470],
[ 0.18401137, -0.06671639, 0.14601991, ..., 0.09149401,
-0.00209353, -0.00024496],
[ 0.18337716, -0.06622121, 0.14607997, ..., 0.09106418,
-0.00173651, -0.00173810]]])
(Pdb) p BN_out_test2
tensor([[[ 0.12310345, -0.06452110, 0.04474467, ..., 0.06025450,
-0.02436449, 0.03219126],
[ 0.16362284, -0.08077045, 0.07871307, ..., 0.09137955,
-0.04021320, 0.03032403],
[ 0.18024816, -0.08485514, 0.09805696, ..., 0.10307456,
-0.04448152, 0.02867013],
...,
[ 0.18537879, -0.06823060, 0.14564055, ..., 0.09265027,
-0.00357649, 0.00375958],
[ 0.18468492, -0.06737494, 0.14588960, ..., 0.09201526,
-0.00267488, 0.00157468],
[ 0.18401131, -0.06671640, 0.14601973, ..., 0.09149403,
-0.00209353, -0.00024493]]])
As you can see, the values differ slightly in the last few decimal places. I am confused about why this is happening. I'm also confused about why slicing the input to a length greater than 15 (e.g. `[:, 0:16, :]`) reproduces the full-sequence output exactly. Does anyone know why this is happening? Is this a bug in PyTorch?