Implementing a GRU from scratch

Hi everyone,
I need to implement the Gated Recurrent Unit (GRU) myself because I will need to modify some terms in its equations. First, though, I would like to implement the original GRU to make sure that everything works correctly. Below is my implementation. I ran a simple test comparing its outputs with those of the PyTorch implementation (`torch.nn.GRU`), but the results differ.
Where am I going wrong?

``````import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

import math

class GRU(nn.Module):
    """Single-layer gated recurrent unit written out by hand.

    The gate equations and gate order (reset, update, new) follow
    ``torch.nn.GRU``::

        r  = sigmoid(x Wx_r + bx_r + h Wh_r + bh_r)
        z  = sigmoid(x Wx_z + bx_z + h Wh_z + bh_z)
        n  = tanh(x Wx_n + bx_n + r * (h Wh_n + bh_n))
        h' = (1 - z) * n + z * h

    ``Wx`` has shape ``(input_size, 3 * hidden_size)`` and ``Wh`` has shape
    ``(hidden_size, 3 * hidden_size)``. These are the *transposes* of
    ``torch.nn.GRU``'s ``weight_ih_l0`` / ``weight_hh_l0``.

    To compare against ``torch.nn.GRU(..., batch_first=True)``, copy these
    into this module:

    * ``weight_ih_l0.T`` into ``Wx``
    * ``weight_hh_l0.T`` into ``Wh``
    * ``bias_ih_l0`` / ``bias_hh_l0`` into ``bx`` / ``bh``

    Otherwise the two modules start from different random weights and their
    outputs will differ.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Scale *before* wrapping in Parameter. ``Parameter(t).div(c)`` returns
        # an ordinary (non-leaf) tensor, so the weight would not be registered
        # in ``self.parameters()``. It would then never be trained and would
        # not move with ``.to(device)``.
        self.Wx = Parameter(torch.randn(input_size, hidden_size * 3) / math.sqrt(input_size))
        self.bx = Parameter(torch.zeros(hidden_size * 3))
        self.Wh = Parameter(torch.randn(hidden_size, hidden_size * 3) / math.sqrt(hidden_size))
        self.bh = Parameter(torch.zeros(hidden_size * 3))

    def step(self, x, prev_h):
        """Advance the recurrence by one time step.

        Args:
            x: input at this step, shape ``(batch, input_size)``.
            prev_h: previous hidden state, shape ``(batch, hidden_size)``.

        Returns:
            The next hidden state, shape ``(batch, hidden_size)``.
        """
        # Each projection is (batch, 3 * hidden_size). Chunking the last dim
        # yields the reset / update / new slices, in that order.
        x_r, x_z, x_n = (x @ self.Wx + self.bx).chunk(3, dim=-1)
        h_r, h_z, h_n = (prev_h @ self.Wh + self.bh).chunk(3, dim=-1)

        r = torch.sigmoid(x_r + h_r)
        z = torch.sigmoid(x_z + h_z)
        # The reset gate scales the hidden projection *including* its bias,
        # as in torch.nn.GRU.
        n = torch.tanh(x_n + r * h_n)

        return (1 - z) * n + z * prev_h

    def forward(self, x, h0=None):
        """Run the GRU over a whole (batch-first) sequence.

        Args:
            x: input sequence, shape ``(batch, steps, input_size)``.
            h0: initial hidden state, shape ``(batch, hidden_size)``.
                Defaults to zeros when omitted.

        Returns:
            Hidden state after every step, shape ``(batch, steps, hidden_size)``.
        """
        batch_size, steps, _ = x.shape
        hidden_size = self.Wh.shape[0]
        if h0 is None:
            h0 = x.new_zeros(batch_size, hidden_size)

        # new_zeros keeps x's dtype *and* device, so writing GPU states works.
        hts = x.new_zeros(batch_size, steps, hidden_size)

        h_t = h0
        for step_t in range(steps):
            h_t = self.step(x[:, step_t], h_t)
            hts[:, step_t] = h_t

        return hts
``````

Thank you

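This may not be much help, but you could double-check that `reshape()` is doing the right thing. Reshaping the `(batch, 3 * hidden_size)` projection to `(batch, 3, hidden_size)` should put columns `[i * hidden_size:(i + 1) * hidden_size]` into gate slot `i`.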