Why?
This is my own code. I only show the forward pass, because the rest is the same as PyTorch's code.
def forward(self, input, hx=None, cx=None):
    """Run the stacked LSTM over a padded batch, one time step at a time.

    Args:
        input: 3-D tensor laid out ``(batch, seq, feature)`` when
            ``self.batch_first`` is true, ``(seq, batch, feature)`` otherwise.
        hx: Optional initial hidden state, read as ``hx[:, layer, :]``, i.e.
            batch-major ``(batch, num_layers * num_directions, hidden_size)``.
            Defaults to zeros.
        cx: Optional initial cell state with the same layout as ``hx``.
            Defaults to zeros.

    Returns:
        The last layer's hidden state at every time step, in the same
        batch/sequence layout as ``input``. The final ``(h, c)`` pair is
        not returned.

    Raises:
        NotImplementedError: If ``input`` is a ``PackedSequence``.
    """
    if isinstance(input, PackedSequence):
        # The time loop below slices a padded 3-D tensor per step; the
        # flattened data of a packed batch cannot be walked that way.
        raise NotImplementedError(
            "PackedSequence input is not supported; pass a padded tensor")

    time_major = not self.batch_first
    max_batch_size = input.size(1) if time_major else input.size(0)

    # NOTE: only one direction is evaluated below. With bidirectional=True
    # the extra state slots are allocated but never read.
    num_directions = 2 if self.bidirectional else 1
    state_shape = (max_batch_size,
                   self.num_layers * num_directions,
                   self.hidden_size)
    if hx is None:
        hx = input.new_zeros(state_shape, requires_grad=False)
    if cx is None:
        cx = input.new_zeros(state_shape, requires_grad=False)

    # Third argument is batch_sizes, which is always None for padded input.
    self.check_forward_args(input, hx, None)

    # The loop works on a batch-major view so that dim 1 is always time.
    output = input.transpose(0, 1) if time_major else input
    for layer in range(self.num_layers):
        h_state = hx[:, layer, :]
        c_state = cx[:, layer, :]
        steps = []
        for t in range(output.size(1)):
            h_state, c_state = self.LSTMCell(
                output[:, t, :], h_state, c_state, layer)
            steps.append(h_state)
        # This layer's per-step hidden states feed the next layer.
        output = torch.stack(steps, 1)

    return output.transpose(0, 1) if time_major else output
def LSTMCell(self, input, h, c, layer):
    """Advance one layer's LSTM state by a single time step for a whole batch.

    Args:
        input: ``(batch, in_features)`` input for this time step.
        h: ``(batch, hidden)`` previous hidden state.
        c: ``(batch, hidden)`` previous cell state.
        layer: Layer index used to look up ``weight_ih_l{layer}``,
            ``weight_hh_l{layer}`` and ``bias_l{layer}`` on ``self``.
            The bias is expected to be 1-D with ``4 * hidden`` entries.

    Returns:
        Tuple ``(h, c)`` of the new hidden and cell states, each
        ``(batch, hidden)``.
    """
    weight_i = getattr(self, 'weight_ih_l{}'.format(layer))
    weight_h = getattr(self, 'weight_hh_l{}'.format(layer))
    bias = getattr(self, 'bias_l{}'.format(layer))

    # Fuse both projections: [x, h] @ [W_ih | W_hh]^T + b, computed for the
    # whole batch in one matmul instead of one mm per sample.
    weight = torch.cat((weight_i, weight_h), 1)
    gates = torch.cat((input, h), 1).mm(weight.t()) + bias

    # Row blocks of the fused weight are ordered [g | i | f | o]. This differs
    # from nn.LSTM's [i | f | g | o], so parameters are not interchangeable.
    hidden = weight_h.size(0) // 4
    candidate = torch.tanh(gates[:, :hidden])
    squashed = torch.sigmoid(gates[:, hidden:])
    in_gate = squashed[:, :hidden]
    forget_gate = squashed[:, hidden:2 * hidden]
    out_gate = squashed[:, 2 * hidden:]

    c = candidate * in_gate + c * forget_gate
    h = torch.tanh(c) * out_gate
    return h, c