Make a GRU that is interchangeable with an LSTM

Hi folks,

I might’ve missed the point here, but I am experimenting with switching between a GRU and an LSTM. It is annoying in downstream applications to need different code paths depending on whether the state is a tuple (LSTM) or a tensor (GRU) (see this discussion). I have started using the following class, which wraps the call to the GRU to make it mimic the return signature of the LSTM (a so-called GRU-mimic, or GRUm).

I’m not sure if there is a better way of doing this. The code is in a colab here. It passes tests suggesting that the GRU behavior is unmodified while matching the signature of the LSTM.

It does incur some extra memory to store the (unused) cell states, but it is much easier because the decision about which entries are relevant can be made downstream.


import torch


class GRUm(torch.nn.GRU):
    """Define a GRU that mimics the return signature of an LSTM (GRUm).

    This only requires wrapping the forward method to strip and then
    re-add a cell state.
    """

    def forward(self, input, hxcx=None):
        # Unpack the (hidden, cell) state tuple if one was given; the GRU
        # itself only consumes the hidden state.
        if hxcx is not None:
            hx, cx = hxcx
        else:
            hx, cx = None, None

        output, hx = super().forward(input, hx)

        # Fabricate an (unused) cell state so the return matches an LSTM.
        # torch.zeros_like already places cx on hx's device.
        if cx is None:
            cx = torch.zeros_like(hx)

        return output, (hx, cx)
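To illustrate the point of the wrapper, here is a minimal, self-contained sketch of how a single downstream code path can then drive either an `nn.LSTM` or a `GRUm` (the dimensions here are arbitrary, chosen just for the demo):

```python
import torch


class GRUm(torch.nn.GRU):
    """GRU whose forward() returns (output, (h, c)) like an LSTM."""

    def forward(self, input, hxcx=None):
        hx, cx = hxcx if hxcx is not None else (None, None)
        output, hx = super().forward(input, hx)
        if cx is None:
            cx = torch.zeros_like(hx)  # dummy cell state
        return output, (hx, cx)


# One downstream code path works unchanged for both modules.
for rnn in (torch.nn.LSTM(8, 16, batch_first=True),
            GRUm(8, 16, batch_first=True)):
    x = torch.randn(4, 10, 8)        # (batch, seq, features)
    output, (h, c) = rnn(x)          # identical unpacking for both
    output2, state = rnn(x, (h, c))  # the state tuple can be fed back in
```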

Although it probably doesn’t fit your use case: I often build “customizable” models where I can decide between an LSTM and a GRU via an input parameter.

Here is an example; search for class RnnTextClassifier. Based on a simple string input parameter, the classifier uses either an nn.RNN, nn.LSTM, or nn.GRU. The different outputs are then handled within the forward() method, most basically by checking whether the hidden state is a tensor (nn.RNN, nn.GRU) or a tuple (nn.LSTM).
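That pattern can be sketched roughly as follows (names and dimensions here are my own illustration, not taken from the linked RnnTextClassifier code):

```python
import torch
import torch.nn as nn


class ConfigurableRnn(nn.Module):
    """Sketch: pick the recurrent layer type from a string parameter."""

    _RNNS = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}

    def __init__(self, input_size, hidden_size, rnn_type="gru"):
        super().__init__()
        self.rnn = self._RNNS[rnn_type](input_size, hidden_size,
                                        batch_first=True)

    def forward(self, x):
        output, hidden = self.rnn(x)
        # nn.LSTM returns a (h_n, c_n) tuple; nn.RNN and nn.GRU
        # return h_n alone, so normalize to the hidden state tensor.
        if isinstance(hidden, tuple):
            hidden = hidden[0]
        return output, hidden
```

The isinstance check on the returned state is the single place where the LSTM/GRU difference is handled, so the rest of the model stays agnostic to the choice.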