I tried adding peephole connection for LSTM using torch.nn.Linear(), but it really is slow to calculate.
Is there any better way to make it more faster?
My current code is below.
Any help would be appreciated.
class Handmade_LSTMCell(torch.nn.Module):
def __init__(self, _input_size, _hidden_size):
super().__init__()
self.input_size = _input_size
self.hidden_size = _hidden_size
#input gate
self.linear_ii = torch.nn.Linear(_input_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_ii.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_ii.bias)
self.linear_hi = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_hi.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_hi.bias)
self.input_activation = torch.nn.Sigmoid()
#forget gate
self.linear_if = torch.nn.Linear(_input_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_if.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_if.bias)
self.linear_hf = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_hf.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_hf.bias)
self.forget_activation = torch.nn.Sigmoid()
#cell gate
self.linear_ic = torch.nn.Linear(_input_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_ic.weight, gain = 5.0/3.0)
torch.nn.init.zeros_(self.linear_ic.bias)
self.linear_hc = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_hc.weight, gain = 5.0/3.0)
torch.nn.init.zeros_(self.linear_hc.bias)
self.cell_activation = torch.nn.Tanh()
#output gate
self.linear_io = torch.nn.Linear(_input_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_io.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_io.bias)
self.linear_ho = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_ho.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_ho.bias)
self.output_activation = torch.nn.Sigmoid()
#peephole connection
self.linear_ci = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_ci.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_ci.bias)
self.linear_cf = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_cf.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_cf.bias)
self.linear_co = torch.nn.Linear(_hidden_size, _hidden_size, bias = True)
torch.nn.init.xavier_uniform_(self.linear_co.weight, gain = 1.0)
torch.nn.init.zeros_(self.linear_co.bias)
def forward(self, _input, _state = None):
is_batched = _input.dim() == 2
if is_batched:
_input = _input.unsqueeze(0)
#_inpt.size(0) = sequence length
if _state == None:
hx = torch.zeros(_input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
cx = torch.zeros(_input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
else:
hx, cx = _state
outputs = torch.empty(_input.size(1), _input.size(0), self.hidden_size, dtype=_input.dtype, device=_input.device)
for i in range(_input.size(1)):
input_gate = self.input_activation(self.linear_ii(_input[:,i,:]) + self.linear_hi(hx) + self.linear_ci(cx))
forget_gate = self.forget_activation(self.linear_if(_input[:,i,:]) + self.linear_hf(hx) + self.linear_cf(cx))
cell_gate = self.cell_activation(self.linear_ic(_input[:,i,:]) + self.linear_hc(hx))
cy = (forget_gate * cx) + (input_gate * cell_gate)
output_gate = self.output_activation(self.linear_io(_input[:,i,:]) + self.linear_ho(hx) + self.linear_co(cy))
hx = output_gate * torch.tanh(cy)
outputs[i] = hx
cx = cy
if is_batched:
return outputs.transpose(0, 1).contiguous().squeeze(0), (hx, cy)
else:
return outputs.transpose(0, 1).contiguous(), (hx.unsqueeze(0), cy.unsqueeze(0))