Hi, I am also getting the above mention error after first iteration. Below is the two layers Residual LSTM. The input size is different for each layer but the hidden size is the same (256). For the first layer the input size is 1088 for second it is 256. I think there is error in self.weight_ir
but I am not sure. Can you guide me?
class RLSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size, dropout=0.):
super(RLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.zeros(4 * hidden_size, input_size))
torch.nn.init.xavier_uniform_(self.weight_ih)
self.weight_hh = Parameter(torch.zeros(4 * hidden_size, hidden_size))
torch.nn.init.xavier_uniform_(self.weight_hh)
self.bias_ih = Parameter(torch.zeros(4 * hidden_size))
self.bias_hh = Parameter(torch.zeros(4 * hidden_size))
self.weight_ir = Parameter(torch.zeros(hidden_size, input_size))
torch.nn.init.xavier_uniform_(self.weight_ir)
@jit.script_method
def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
hx, cx = state
gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
ry = torch.tanh(cy) #eqution 12 in the paper https://arxiv.org/pdf/1701.03360v3.pdf
#hy = outgate * torch.tanh(cy)
if self.input_size == self.hidden_size:
hy = outgate * (ry + input) #eqution 15 in the paper
else:
hy = outgate * (ry + torch.mm(input, self.weight_ir.t()))
return hy, (hy, cy)
class LSTMLayer(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super(LSTMLayer, self).__init__()
self.layer1 = RLSTMCell(input_size, hidden_size)
self.layer2 = RLSTMCell(hidden_size, hidden_size)
@jit.script_method
def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.layer1(inputs[i], state)
out, state= self.layer2(state[0], state)
outputs += [out]
return torch.stack(outputs), state