[GRU] Write custom cell

Hi!

After using the existing GRU implementation quite extensively, I am wondering whether all of its intermediate gates are actually relevant for my task (in particular, whether a single forget gate would suffice), following the ideas of https://arxiv.org/abs/1603.09420.
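For reference, the minimal gated unit (MGU) in the linked paper replaces the GRU's reset and update gates with a single forget gate $f_t$. As I read the paper, one step is (with $[\cdot,\cdot]$ denoting concatenation and $\odot$ the element-wise product):

$$
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) \\
\tilde{h}_t &= \tanh\left(W_h\,[f_t \odot h_{t-1}, x_t] + b_h\right) \\
h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
\end{aligned}
$$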

However, when digging into the GRU implementation (https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#GRU), I came across the following:
```python
return self._backend.GRUCell(
    input, hx,
    self.weight_ih, self.weight_hh,
    self.bias_ih, self.bias_hh,
)
```

It seems that `_backend` refers to THNN in my case, but I couldn't find the corresponding source code, and I'm not sure I should be tampering with it anyway…
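To see what that backend call computes without reading the THNN source, here is a minimal sketch of one GRU step in plain PyTorch ops. It is written from the equations in the `nn.GRUCell` documentation (weights stacked as reset, update, new gate blocks), not from the backend code, and is checked numerically against the built-in cell. The function name `gru_cell_reference` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gru_cell_reference(x, h, w_ih, w_hh, b_ih, b_hh):
    """One GRU step following the nn.GRUCell documentation equations.

    w_ih: (3*hidden, input), w_hh: (3*hidden, hidden); the three row
    blocks correspond to the reset (r), update (z) and new (n) gates.
    """
    gi = F.linear(x, w_ih, b_ih)  # input-to-hidden contributions
    gh = F.linear(h, w_hh, b_hh)  # hidden-to-hidden contributions
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)   # reset gate
    z = torch.sigmoid(i_z + h_z)   # update gate
    n = torch.tanh(i_n + r * h_n)  # candidate state, reset applied to hidden part
    return (1 - z) * n + z * h     # interpolate candidate and previous state


torch.manual_seed(0)
cell = nn.GRUCell(input_size=4, hidden_size=3)
x = torch.randn(2, 4)  # (batch, input_size)
h = torch.randn(2, 3)  # (batch, hidden_size)
ref = gru_cell_reference(
    x, h, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh
)
print(torch.allclose(ref, cell(x, h), atol=1e-5))
```

Because everything here is an ordinary differentiable tensor op, autograd handles the backward pass, so no backend code is needed to experiment with variants.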

Is there any way for me to write a custom GRUCell (e.g. a MinimalGRUCell) that would replace `_backend.GRUCell`?
If so, what conventions should I respect to ensure that the resulting cell can be used without causing issues down the line? (e.g. can I just write a Module using functional operations and swap it in?)

Best regards,
Arnaud
