Make a GRU that is interchangeable with an LSTM

Hi folks,

I might’ve missed the point here, but I am experimenting with switching between a GRU and an LSTM. It is kind of annoying in the downstream applications to need different code paths depending on whether the state is a tuple (LSTM) or a tensor (GRU) (see this discussion). I have started using the following class, which wraps the call to the GRU so that it mimics the return signature of the LSTM (a so-called GRU-mimic, or GRUm).
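For reference, the asymmetry in question is just the standard return types of the two modules (toy sizes here):

import torch

gru = torch.nn.GRU(input_size=10, hidden_size=20)
lstm = torch.nn.LSTM(input_size=10, hidden_size=20)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

out, h_n = gru(x)          # hidden state is a single tensor
out, (h_n, c_n) = lstm(x)  # hidden state is an (h_n, c_n) tuple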

I’m not sure if there is a better way of doing this. The code is in a colab here. It passes tests suggesting that the underlying GRU behaviour is unmodified while the output matches the signature of the LSTM.

It does incur some extra memory for storing the (unused) cell states, but it is so much easier because the decision about which entries are relevant can be made downstream.

Thanks,
Andy

import torch

class GRUm(torch.nn.GRU):
    """
    Define a GRU module that mimics the return signature of an LSTM (_GRUm_).

    This only requires wrapping the forward method to strip and then re-add a cell state.
    """

    def forward(self, input, hxcx=None):
        # Unpack an LSTM-style (hidden, cell) tuple if one is supplied.
        if hxcx is not None:
            hx, cx = hxcx
        else:
            hx = None
            cx = None

        # Run the unmodified GRU forward pass with just the hidden state.
        output, hx = super().forward(input, hx)

        # Return a placeholder cell state so the output matches an LSTM's.
        if cx is None:
            cx = torch.zeros_like(hx)  # already on hx's device and dtype

        return output, (hx, cx)
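To show the drop-in behaviour (the shapes here are arbitrary), the same downstream unpacking works whether the module is a GRUm or an nn.LSTM:

rnn = GRUm(input_size=10, hidden_size=20, num_layers=2)  # or torch.nn.LSTM(10, 20, 2)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

# LSTM-style unpacking and feeding the state back in work in both cases.
output, (hn, cn) = rnn(x)
output, (hn, cn) = rnn(x, (hn, cn))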

Although it probably doesn’t fit your use case: I often build “customizable” models where I can decide whether to use an LSTM or a GRU via an input parameter.

Here is an example; search for class RnnTextClassifier. Based on a simple string input parameter, the classifier uses either an nn.RNN, nn.LSTM, or nn.GRU. The different outputs are then handled within the forward() method, essentially by checking whether the hidden state is a tensor (nn.RNN, nn.GRU) or a tuple (nn.LSTM); a rough sketch of the idea is below.
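This isn’t the actual RnnTextClassifier from the linked example, just a minimal sketch of the idea (class name and sizes are made up):

import torch
import torch.nn as nn

class ConfigurableRnn(nn.Module):
    """Sketch: choose the recurrent layer via a string parameter."""

    def __init__(self, rnn_type="lstm", input_size=10, hidden_size=20, num_layers=1):
        super().__init__()
        rnn_cls = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}[rnn_type.lower()]
        self.rnn = rnn_cls(input_size, hidden_size, num_layers)

    def forward(self, x):
        output, hidden = self.rnn(x)
        # nn.LSTM returns an (h_n, c_n) tuple; nn.RNN and nn.GRU return a tensor.
        h_n = hidden[0] if isinstance(hidden, tuple) else hidden
        # Use the last layer's hidden state as the sequence representation.
        return h_n[-1]

model = ConfigurableRnn(rnn_type="gru")
features = model(torch.randn(5, 3, 10))  # -> (batch, hidden_size)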