Hi
I am trying to implement a custom bidirectional GRU network, but I am unsure how exactly to handle the input so that I get the correct output for both directions of the network. My implementation is very similar to the bidirectional LSTM implementation found here: https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
My implementation is as follows:
class GRUCell(nn.Module):
    """Single GRU time step using the same gate equations as ``torch.nn.GRUCell``.

    Parameters use the ``nn.GRUCell`` layout: the reset, update and new gate rows
    are stacked in ``weight_ih`` / ``weight_hh`` / ``bias_ih`` / ``bias_hh``.
    They are initialized from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)), the
    scheme ``nn.GRUCell`` documents. A unit-variance ``torch.randn`` init tends
    to saturate the sigmoid/tanh gates.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / self.hidden_size ** 0.5 if self.hidden_size > 0 else 0.0
        for param in self.parameters():
            nn.init.uniform_(param, -bound, bound)

    def forward(self, inp, hidden=None):
        """Compute the next hidden state.

        Args:
            inp: 2-D tensor of shape (batch, input_size).
            hidden: 2-D tensor of shape (batch, hidden_size). When omitted or
                ``None``, a zero state is used (as ``nn.GRUCell`` does).

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        if hidden is None:
            hidden = inp.new_zeros(inp.size(0), self.hidden_size)
        gate_input = torch.mm(inp, self.weight_ih.t()) + self.bias_ih
        gate_hidden = torch.mm(hidden, self.weight_hh.t()) + self.bias_hh
        i_reset, i_input, i_new = gate_input.chunk(3, 1)
        h_reset, h_input, h_new = gate_hidden.chunk(3, 1)
        reset_gate = torch.sigmoid(i_reset + h_reset)
        input_gate = torch.sigmoid(i_input + h_input)  # the "update" gate z
        # The reset gate scales the whole hidden contribution, including bias_hh:
        # n = tanh(W_in x + b_in + r * (W_hn h + b_hn)), as in nn.GRUCell.
        new_gate = torch.tanh(i_new + reset_gate * h_new)
        # Algebraically equal to (1 - z) * n + z * h.
        return new_gate + input_gate * (hidden - new_gate)
def _as_lengths(lengths, state):
    """Convert optional ``lengths`` to a tensor on ``state``'s device (``None`` passes through)."""
    if lengths is None:
        return None
    return torch.as_tensor(lengths, device=state.device)


def _valid_mask(lengths, t):
    """Return a (batch, 1) bool mask that is True where step ``t`` is real data, else ``None``.

    Step ``t`` is real for sequence ``b`` iff ``t < lengths[b]``.
    """
    if lengths is None:
        return None
    return (lengths > t).unsqueeze(-1)


def _masked_step(cell, x_t, state, valid):
    """Advance ``cell`` one step and return ``(next_state, output)``.

    Where ``valid`` is False (a padding step), the state is carried through
    unchanged and the output is zero. Padding therefore never affects the
    hidden state in either direction.
    """
    new_state = cell(x_t, state)
    if valid is None:
        return new_state, new_state
    next_state = torch.where(valid, new_state, state)
    output = torch.where(valid, new_state, torch.zeros_like(new_state))
    return next_state, output


class GRULayer(nn.Module):
    """Run a recurrent cell over a time-major sequence from the first step to the last.

    Args:
        cell: Cell class, instantiated as ``cell(*cell_args)`` and called as
            ``cell(x_t, state)`` to get the next state.
    """

    def __init__(self, cell, *cell_args):
        super(GRULayer, self).__init__()
        self.cell = cell(*cell_args)

    def forward(self, inp, state, lengths=None):
        """Process ``inp`` step by step along dim 0.

        Args:
            inp: Tensor whose first dimension is time; ``inp[t]`` is fed at step t.
            state: Initial hidden state.
            lengths: Optional 1-D tensor or list with each sequence's unpadded
                length. When given, steps at or beyond a sequence's length leave
                its state unchanged and emit zeros. The returned state is then
                the state after that sequence's last real token.

        Returns:
            ``(outputs, final_state)``, where ``outputs`` stacks the per-step
            outputs along dim 0.
        """
        lengths = _as_lengths(lengths, state)
        outputs = []
        for t, x_t in enumerate(inp.unbind(0)):
            state, out = _masked_step(self.cell, x_t, state, _valid_mask(lengths, t))
            outputs.append(out)
        return torch.stack(outputs), state


class GRUReverseLayer(nn.Module):
    """Run a recurrent cell over a time-major sequence from the last step to the first.

    Args:
        cell: Cell class, instantiated as ``cell(*cell_args)`` and called as
            ``cell(x_t, state)`` to get the next state.
    """

    def __init__(self, cell, *cell_args):
        super(GRUReverseLayer, self).__init__()
        self.cell = cell(*cell_args)

    def forward(self, inp, state, lengths=None):
        """Process ``inp`` in reverse time order; outputs are returned in original order.

        Args:
            inp: Tensor whose first dimension is time.
            state: Initial hidden state.
            lengths: Optional 1-D tensor or list with each sequence's unpadded
                length. Trailing padding steps are visited first. Because those
                steps leave the state untouched, each sequence effectively
                starts from ``state`` at its last real token. Padding positions
                emit zeros.

        Returns:
            ``(outputs, final_state)``. ``outputs[t]`` is aligned with ``inp[t]``,
            and ``final_state`` is the state after processing step 0.
        """
        lengths = _as_lengths(lengths, state)
        inputs = inp.unbind(0)
        outputs = []
        for t in reversed(range(len(inputs))):
            state, out = _masked_step(self.cell, inputs[t], state, _valid_mask(lengths, t))
            outputs.append(out)
        # Collected last-to-first; flip once instead of prepending (O(n^2)).
        outputs.reverse()
        return torch.stack(outputs), state


class BidirGRULayer(nn.Module):
    """Bidirectional wrapper: a forward and a reverse layer, each with its own cell.

    Args:
        cell: Cell class passed to both directions.
        *cell_args: Constructor arguments for ``cell``.
    """

    def __init__(self, cell, *cell_args):
        super(BidirGRULayer, self).__init__()
        self.directions = nn.ModuleList([
            GRULayer(cell, *cell_args),
            GRUReverseLayer(cell, *cell_args),
        ])

    def forward(self, inp, states, lengths=None):
        """Run both directions over ``inp``.

        Args:
            inp: Time-major input fed to both directions.
            states: Indexable so that ``states[0]`` / ``states[1]`` are the
                initial states of the forward / reverse direction.
            lengths: Optional unpadded length per sequence, forwarded to both
                directions so padding is ignored.

        Returns:
            ``(outputs, output_states)``. ``outputs`` concatenates both
            directions' outputs along the last dim, and ``output_states`` is
            ``[forward_final, reverse_final]``.
        """
        outputs = []
        output_states = []
        for i, direction in enumerate(self.directions):
            out, out_state = direction(inp, states[i], lengths)
            outputs.append(out)
            output_states.append(out_state)
        return torch.cat(outputs, -1), output_states
Let's say I want to run the network on the following input: a batch of two sequences, each padded to three time steps:
# Two token sequences padded to three time steps; token 0 is the padding index.
inp = torch.tensor([[1, 2, 3],
                    [4, 5, 0]])
# Unpadded length of each sequence, in batch order.
inp_lengths = torch.tensor([3, 2], dtype=torch.int16)
embedding = nn.Embedding(10, 3, padding_idx=0)
# Swap to time-major (seq_len, batch) so the layers can unbind over dim 0.
embedded = embedding(inp.transpose(0, 1))
# One zero initial state per direction: (num_directions, batch, hidden_size).
initial_hidden = torch.zeros(2, 2, 3)
The second sequence is padded with a zero (the embedding's padding_idx) so that both sequences have the same length, and the input is then embedded. I can then create and call my network with the input as follows:
# Each direction gets its own GRUCell(input_size=3, hidden_size=3).
gru = BidirGRULayer(GRUCell, 3, 3)
# out: both directions' per-step outputs concatenated on the last dim;
# h: [forward_final_state, reverse_final_state].
out, h = gru(embedded, initial_hidden)
Now my question is whether this will give me the desired output. I know that, for the forward direction, I can get the correct final hidden states by using the lengths of the unpadded sequences, so that I don't pick up the hidden state produced by the padding value in the second sequence. But how do I handle this for the backward direction, where the padding value is fed to the network as the first time step?
Let's say the output is as follows:
tensor([[[-0.6668, -0.1728, -0.2585, -0.9336, 0.9793, 0.8094],
[-0.3246, -0.6246, 0.3892, -0.9491, 0.9982, 0.7115]],
[[-0.6869, -0.9550, -0.1347, -0.9154, 0.9353, 0.9857],
[-0.8382, -0.6098, 0.2500, -0.9753, 0.9982, 0.9772]],
[[-0.8331, -0.9331, -0.1964, -0.3474, 0.9081, -0.4695],
[-0.8896, -0.6312, -0.0221, -0.0326, 0.7980, -0.5642]]])
The final hidden states for the forward direction can be correctly identified as [[-0.8331, -0.9331, -0.1964], [-0.8382, -0.6098, 0.2500]] by taking the input lengths into account (3 for the first sequence, 2 for the second sequence). For the backward direction, however, the final hidden states would be [[-0.9336, 0.9793, 0.8094], [-0.9491, 0.9982, 0.7115]] (the last three features at the first time step). But I am unsure whether this is correct, since the backward direction sees the padding value (0) first.
Would it be better to use the GRULayer class (the forward-direction layer) twice and make a reversed copy of the input, such that one layer gets the embedding of [[1,2,3],[4,5,0]] while the other gets an embedding of [[3,2,1],[5,4,0]] (each sequence reversed only within its true length, so the padding stays at the end)?