I am trying to implement a custom bidirectional GRU network, but I am unsure exactly how to handle the input so that I get the correct output for both directions of the network. My implementation is very similar to the bidirectional LSTM implementation found here: https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
My implementation is as follows:
class GRUCell(nn.Module):
    """Single-timestep GRU cell.

    Holds the stacked reset/update/new weights (``3 * hidden_size`` rows)
    for the input and for the previous hidden state, plus matching biases.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))

    def forward(self, inp, hidden):
        """Advance the hidden state by one timestep.

        Args:
            inp: 2-D tensor ``(batch, input_size)`` for the current step.
            hidden: 2-D tensor ``(batch, hidden_size)``, the previous state.

        Returns:
            The next hidden state, shape ``(batch, hidden_size)``.
        """
        gate_input = torch.mm(inp, self.weight_ih.t()) + self.bias_ih
        gate_hidden = torch.mm(hidden, self.weight_hh.t()) + self.bias_hh
        i_reset, i_update, i_new = gate_input.chunk(3, 1)
        h_reset, h_update, h_new = gate_hidden.chunk(3, 1)
        reset_gate = torch.sigmoid(i_reset + h_reset)
        update_gate = torch.sigmoid(i_update + h_update)
        new_gate = torch.tanh(i_new + reset_gate * h_new)
        # Equivalent to update_gate * hidden + (1 - update_gate) * new_gate.
        return new_gate + update_gate * (hidden - new_gate)


def _run_direction(cell, inp, state, lengths=None, reverse=False):
    """Unroll ``cell`` over the time axis of ``inp`` in one direction.

    Args:
        cell: callable mapping ``(step_input, state)`` to the next state.
        inp: time-major tensor ``(seq_len, batch, input_size)``.
        state: initial hidden state ``(batch, hidden_size)``.
        lengths: optional per-sequence valid lengths ``(batch,)``. When given,
            timesteps ``t >= lengths[b]`` are treated as right-padding: they
            leave the state of sequence ``b`` untouched and their output rows
            are zero. This is what keeps padding out of the reverse direction.
        reverse: walk the time axis from the last step to the first.

    Returns:
        ``(outputs, final_state)`` where ``outputs`` is stacked in the
        original time order, so ``outputs[t]`` always belongs to input step
        ``t`` regardless of direction.
    """
    steps = inp.unbind(0)
    if not steps:
        # Nothing to unroll: empty output, state passes through unchanged.
        return state.new_zeros((0,) + tuple(state.shape)), state
    if lengths is not None:
        lengths = torch.as_tensor(lengths, device=inp.device)

    num_steps = len(steps)
    order = range(num_steps - 1, -1, -1) if reverse else range(num_steps)
    # Filled by index so the reverse pass needs no O(n) list prepends.
    outputs = [None] * num_steps
    for t in order:
        candidate = cell(steps[t], state)
        if lengths is None:
            state = candidate
            outputs[t] = candidate
        else:
            valid = (lengths > t).unsqueeze(1)  # (batch, 1), broadcasts over hidden
            state = torch.where(valid, candidate, state)
            outputs[t] = torch.where(valid, candidate, torch.zeros_like(candidate))
    return torch.stack(outputs), state


class GRULayer(nn.Module):
    """Runs a recurrent cell over a sequence from the first step to the last."""

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, inp, state, lengths=None):
        """Return ``(outputs, final_state)``; see ``_run_direction``.

        ``lengths`` is optional; omitting it processes every timestep,
        including padding, exactly as before.
        """
        return _run_direction(self.cell, inp, state, lengths, reverse=False)


class GRUReverseLayer(nn.Module):
    """Runs a recurrent cell over a sequence from the last step to the first."""

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, inp, state, lengths=None):
        """Return ``(outputs, final_state)`` with outputs in input time order.

        The final state is the one produced at timestep 0. Pass ``lengths``
        for right-padded batches so padded steps do not alter the state
        before the real tokens are reached.
        """
        return _run_direction(self.cell, inp, state, lengths, reverse=True)


class BidirGRULayer(nn.Module):
    """Forward and reverse GRU layers whose outputs are concatenated."""

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList([
            GRULayer(cell, *cell_args),
            GRUReverseLayer(cell, *cell_args),
        ])

    def forward(self, inp, states, lengths=None):
        """Run both directions over ``inp``.

        Args:
            inp: time-major tensor ``(seq_len, batch, input_size)``.
            states: indexable of initial states, ``states[0]`` for the forward
                direction and ``states[1]`` for the reverse direction.
            lengths: optional per-sequence valid lengths for right-padded
                batches, forwarded to both directions.

        Returns:
            ``(outputs, output_states)``: outputs concatenated on the last
            axis (forward features first), and a list holding the final state
            of each direction.
        """
        outputs = []
        output_states = []
        for direction, state in zip(self.directions, states):
            out, out_state = direction(inp, state, lengths)
            outputs.append(out)
            output_states.append(out_state)
        return torch.cat(outputs, -1), output_states
Let's say I want to run the network on the following input, a batch of two sequences with three timesteps each:
# Batch of two token-id sequences; the second is right-padded with 0.
inp = torch.tensor([[1, 2, 3], [4, 5, 0]])
# Unpadded length of each sequence in the batch.
inp_lengths = torch.tensor([3, 2], dtype=torch.int16)
# padding_idx=0 makes token id 0 map to an all-zero embedding vector.
embedding = nn.Embedding(10, 3, padding_idx=0)
# Transpose to time-major (seq_len, batch) before embedding.
embedded = embedding(inp.t())
# One zero initial state per direction, indexed on the first axis.
initial_hidden = torch.zeros(2, 2, 3)
The second sequence is padded with a zero so that both sequences have the same length, and the padded input is then embedded. I can then create and call my network with the embedded input as follows:
# Input size and hidden size are both 3 here.
gru = BidirGRULayer(GRUCell, 3, 3)
# out: per-timestep outputs of both directions; h: final state per direction.
out, h = gru(embedded, initial_hidden)
Now my question is whether this will give me the desired output. I know that I can get the correct final hidden states for the forward pass by using the lengths of the unpadded sequences, so that I do not use the hidden state produced by the padding value of the second sequence. But how do I handle this for the backward pass, where the padding value is fed to the network as the first timestep?
Let's say the output is as follows:
tensor([[[-0.6668, -0.1728, -0.2585, -0.9336, 0.9793, 0.8094], [-0.3246, -0.6246, 0.3892, -0.9491, 0.9982, 0.7115]], [[-0.6869, -0.9550, -0.1347, -0.9154, 0.9353, 0.9857], [-0.8382, -0.6098, 0.2500, -0.9753, 0.9982, 0.9772]], [[-0.8331, -0.9331, -0.1964, -0.3474, 0.9081, -0.4695], [-0.8896, -0.6312, -0.0221, -0.0326, 0.7980, -0.5642]]])
The final hidden states for the forward pass can be correctly identified as [[-0.8331, -0.9331, -0.1964], [-0.8382, -0.6098, 0.2500]] by taking the input lengths into account (3 for the first sequence, 2 for the second sequence). However, for the backward pass the final hidden states would be [[-0.9336, 0.9793, 0.8094], [-0.9491, 0.9982, 0.7115]]. I am unsure whether this is correct, since the backward pass sees the padding value (0) first.
Is it better to use the GRULayer class (forward pass layer) twice and make a reversed copy of the input such that one layer gets the embedding of [[1,2,3],[4,5,0]] while the other layer gets an embedding of [[3,2,1],[5,4,0]]?