Implementing a GRU from scratch

Hi everyone,
I need to implement the Gated Recurrent Unit (GRU) myself because I will need to modify some terms in its equations. First, though, I would like to implement the original GRU to make sure everything works correctly. Below is my implementation. I ran a simple test comparing its outputs with those of the PyTorch implementation, but the results are different.
Where am I going wrong?
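
For reference, these are the GRU equations I am trying to reproduce (the same formulation as in the nn.GRU documentation):

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))
h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}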

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

import math

class GRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()

        # Note: Parameter(...).div(...) returns a plain tensor, not a Parameter,
        # so these weights were never registered with the module (and would not
        # be trained). Scale the initial values before wrapping them instead.
        self.Wx = Parameter(torch.randn(input_size, hidden_size * 3) / math.sqrt(input_size))
        self.bx = Parameter(torch.zeros(hidden_size * 3))
        self.Wh = Parameter(torch.randn(hidden_size, hidden_size * 3) / math.sqrt(hidden_size))
        self.bh = Parameter(torch.zeros(hidden_size * 3))

    def step(self, x, prev_h):
        batch_size, hidden_size = prev_h.shape

        # Project the input and previous hidden state for all three gates at
        # once, then split the 3 * hidden_size columns into per-gate chunks.
        ax = x.mm(self.Wx) + self.bx
        ax = ax.reshape(batch_size, 3, hidden_size)
        ah = prev_h.mm(self.Wh) + self.bh
        ah = ah.reshape(batch_size, 3, hidden_size)

        # Gate order: reset (r), update (z), candidate (n).
        r = torch.sigmoid(ax[:, 0, :] + ah[:, 0, :])
        z = torch.sigmoid(ax[:, 1, :] + ah[:, 1, :])
        # The reset gate scales only the hidden projection of the candidate.
        n = torch.tanh(ax[:, 2, :] + r * ah[:, 2, :])

        next_h = (1 - z) * n + z * prev_h
        return next_h

    def forward(self, x, h0):
        batch_size, steps, num_features = x.shape
        hidden_size = h0.shape[1]
        # Match the input's device as well as its dtype; .to(dtype=...) alone
        # leaves hts on the CPU and fails for CUDA inputs.
        hts = torch.zeros(batch_size, steps, hidden_size, dtype=x.dtype, device=x.device)

        h_t = h0
        for step_t in range(steps):
            x_t = x[:, step_t]
            h_t = self.step(x_t, h_t)
            hts[:, step_t] = h_t

        return hts
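
For context, this is roughly the test I have in mind (a minimal sketch; the sizes are made up). nn.GRU stores weight_ih_l0 / weight_hh_l0 with shapes (3 * hidden_size, input_size) and (3 * hidden_size, hidden_size), gates ordered r | z | n, so I copy them transposed into the (input_size, 3 * hidden_size) layout used above:

import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size, steps = 4, 5, 2, 3

ref = nn.GRU(input_size, hidden_size, batch_first=True)
gru = GRU(input_size, hidden_size)

# Copy the reference weights, transposing (3 * hidden, in) -> (in, 3 * hidden).
with torch.no_grad():
    gru.Wx.copy_(ref.weight_ih_l0.t())
    gru.Wh.copy_(ref.weight_hh_l0.t())
    gru.bx.copy_(ref.bias_ih_l0)
    gru.bh.copy_(ref.bias_hh_l0)

x = torch.randn(batch_size, steps, input_size)
h0 = torch.zeros(batch_size, hidden_size)

out_ref, _ = ref(x, h0.unsqueeze(0))  # nn.GRU expects h0 as (num_layers, batch, hidden)
out_mine = gru(x, h0)
print(torch.allclose(out_ref, out_mine, atol=1e-6))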

Thank you! 🙂

Not really helpful, but maybe you can double-check that reshape() is doing the right thing.
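
For example, a quick check like this (with made-up sizes) suggests the reshape does split the last dimension into contiguous per-gate chunks, the same slices torch.chunk() would give:

import torch

batch, hidden = 2, 5
a = torch.randn(batch, 3 * hidden)

gates = a.reshape(batch, 3, hidden)          # (batch, gate, hidden)
r_in, z_in, n_in = torch.chunk(a, 3, dim=1)  # three contiguous hidden-sized slices

print(torch.equal(gates[:, 0, :], r_in))  # True
print(torch.equal(gates[:, 1, :], z_in))  # True
print(torch.equal(gates[:, 2, :], n_in))  # True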