Variable length input in custom bidirectional GRU

Hi

I am trying to implement a custom bidirectional GRU network but I am unsure how to exactly deal with the input so that I get the correct output for both directions of the network. My implementation is very similar to the bidirectional LSTM implementation found here: https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py

My implementation is as follows:

import torch
import torch.nn as nn


class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Stacked weights and biases for the reset, update and new gates
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))

    def forward(self, inp, hidden):
        # Project the input and the previous hidden state for all three gates at once
        gate_input = torch.mm(inp, self.weight_ih.t()) + self.bias_ih
        gate_hidden = torch.mm(hidden, self.weight_hh.t()) + self.bias_hh
        i_reset, i_input, i_new = gate_input.chunk(3, 1)
        h_reset, h_input, h_new = gate_hidden.chunk(3, 1)
        reset_gate = torch.sigmoid(i_reset + h_reset)
        input_gate = torch.sigmoid(i_input + h_input)  # update gate z
        new_gate = torch.tanh(i_new + reset_gate * h_new)
        # h' = (1 - z) * n + z * h, written as n + z * (h - n)
        next_hidden = new_gate + input_gate * (hidden - new_gate)
        return next_hidden


class GRULayer(nn.Module):
    def __init__(self, cell, *cell_args):
        super(GRULayer, self).__init__()
        self.cell = cell(*cell_args)
    
    def forward(self, inp, state):
        inputs = inp.unbind(0)
        outputs = []
        for i in range(len(inputs)):
            state = self.cell(inputs[i], state)
            outputs += [state]
        return torch.stack(outputs), state


class GRUReverseLayer(nn.Module):
    def __init__(self, cell, *cell_args):
        super(GRUReverseLayer, self).__init__()
        self.cell = cell(*cell_args)
    
    def forward(self, inp, state):
        inputs = inp.unbind(0)
        outputs = []
        l_inputs = len(inputs)
        # Walk over the timesteps in reverse, but prepend each output so that
        # the returned sequence is still in the original time order
        for i in range(l_inputs):
            j = l_inputs - i - 1
            state = self.cell(inputs[j], state)
            outputs = [state] + outputs
        return torch.stack(outputs), state


class BidirGRULayer(nn.Module):
    def __init__(self, cell, *cell_args):
        super(BidirGRULayer, self).__init__()
        self.directions = nn.ModuleList([
            GRULayer(cell, *cell_args),
            GRUReverseLayer(cell, *cell_args)
        ])
    
    def forward(self, inp, states):
        outputs = []
        output_states = []
        for i, direction in enumerate(self.directions):
            state = states[i]
            out, out_state = direction(inp, state)
            outputs += [out]
            output_states += [out_state]
        return torch.cat(outputs, -1), output_states

Let's say I want to run the network on the following batch of two sequences, each with three timesteps:

inp = torch.tensor([[1,2,3], [4,5,0]])
inp_lengths = torch.tensor([3,2], dtype=torch.int16)
embedding = nn.Embedding(10, 3, padding_idx=0)
embedded = embedding(inp.t())
initial_hidden = torch.zeros(2,2,3)

The second sequence is padded with a zero so that both sequences have the same length, and an embedding of the input is obtained. I can then create and call my network with this input as follows:

gru = BidirGRULayer(GRUCell, 3, 3)
out, h = gru(embedded, initial_hidden)
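
If I am reading my own code correctly, out then has shape (3, 2, 6) (seq_len, batch, 2 * hidden_size) and h is a list of two (2, 3) tensors, one final state per direction:

print(out.shape)   # torch.Size([3, 2, 6])
print(h[0].shape)  # torch.Size([2, 3]) - final state of the forward layer
print(h[1].shape)  # torch.Size([2, 3]) - final state of the reverse layer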

Now my question is whether this will give me the desired output. I know that for the forward direction I can get the correct final hidden states by using the lengths of the unpadded sequences, so that I do not take the hidden state computed from the padding value for the second sequence. But how do I handle this for the backward direction, since there the padding value is fed to the network as the first timestep?
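
For the forward direction I do something roughly like this (just a sketch of the indexing I mean, using the inp_lengths tensor from above):

# Pick the last valid timestep of the forward half of the output
hidden_size = 3
batch_indices = torch.arange(out.size(1))
forward_final = out[inp_lengths.long() - 1, batch_indices, :hidden_size]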

Let's say the output is as follows:

tensor([[[-0.6668, -0.1728, -0.2585, -0.9336,  0.9793,  0.8094],
         [-0.3246, -0.6246,  0.3892, -0.9491,  0.9982,  0.7115]],

        [[-0.6869, -0.9550, -0.1347, -0.9154,  0.9353,  0.9857],
         [-0.8382, -0.6098,  0.2500, -0.9753,  0.9982,  0.9772]],

        [[-0.8331, -0.9331, -0.1964, -0.3474,  0.9081, -0.4695],
         [-0.8896, -0.6312, -0.0221, -0.0326,  0.7980, -0.5642]]])

The final hidden states for the forward direction can be correctly identified as [[-0.8331, -0.9331, -0.1964], [-0.8382, -0.6098, 0.2500]] by taking the input lengths into account (3 for the first sequence, 2 for the second sequence). For the backward direction, however, the final hidden states would be [[-0.9336, 0.9793, 0.8094], [-0.9491, 0.9982, 0.7115]], and I am unsure whether these are correct, since the backward direction sees the padding value (0) first for the second sequence.

Is it better to use the GRULayer class (forward pass layer) twice and make a reversed copy of the input such that one layer gets the embedding of [[1,2,3],[4,5,0]] while the other layer gets an embedding of [[3,2,1],[5,4,0]]?
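
To make that concrete, here is a minimal sketch of the length-aware reversal I have in mind (the helper name reverse_padded is just mine, not an existing PyTorch function):

# Reverse each sequence only up to its own length, so the padding stays at
# the end instead of becoming the first timestep of the reverse direction
def reverse_padded(inp, lengths):
    # inp: (batch, seq_len) token ids, lengths: (batch,)
    reversed_inp = inp.clone()
    for b, length in enumerate(lengths.tolist()):
        reversed_inp[b, :length] = inp[b, :length].flip(0)
    return reversed_inp

reversed_embedded = embedding(reverse_padded(inp, inp_lengths).t())

The second GRULayer would then see the embedding of 5 first for the shorter sequence, and its output could be flipped back the same way before concatenating it with the forward output.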
