Naive implementation of SRU

Hello,
I am trying to implement the Simple Recurrent Unit (SRU). The core idea of SRU lies in Equations (3)-(7) of the paper, and my naive implementation (i.e., without any optimization) of a bi-directional SRU is below:
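For reference, the per-step update that my code computes (my reading of the cell and hidden-state equations, with $\tilde{x}_t$, $f_t$, $r_t$ precomputed from $W x_t$ plus the biases) is:

$$c_t = f_t \odot c_{t-1} + (1 - f_t) \odot \tilde{x}_t$$
$$h_t = r_t \odot \tanh(c_t) + (1 - r_t) \odot \tilde{x}_t$$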

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable

def SRUStep(ft, rt, xt_tilde, ctm1):
	"""
	One SRU time step: update the cell state c_t and compute the hidden state h_t.
	All arguments have shape (bs, dim_h).
	"""
	c_t = ft * ctm1 + (1.0 - ft) * xt_tilde
	h_t = rt * torch.tanh(c_t) + (1.0 - rt) * xt_tilde
	return h_t, c_t

class SRU(nn.Module):
	"""
	implementation of simple recurrent unit
	"""
	def __init__(self, args):
		"""
		args must provide dim_w, dim_h, bs, and sent_len.
		"""
		super(SRU, self).__init__()
		self.n_in = args.dim_w
		self.n_out = args.dim_h
		self.W = Parameter(torch.FloatTensor(np.zeros((self.n_in, 3 * self.n_out))))
		self.bf = Parameter(torch.FloatTensor(np.zeros(self.n_out)))
		self.br = Parameter(torch.FloatTensor(np.zeros(self.n_out)))

		# recurrent dropout
		self.Dropout_recurrent = nn.Dropout(0.5)
		# non-recurrent dropout
		self.Dropout = nn.Dropout(0.3)

		self.sigmoid = nn.Sigmoid()
		self.tanh = nn.Tanh()
		self.bs = args.bs
		self.n_steps = args.sent_len 
		self.init_weight()

	def forward(self, x):
		"""
		x shape: (bs, max_len, dim_x)
		"""
		x = self.Dropout(x)
		# shape: (bs, max_len, 3 * dim_h)
		Wx = torch.bmm(x, self.W.repeat(self.bs, 1, 1))
		# using different dropout masks at different time steps
		Wx = self.Dropout_recurrent(Wx)
		x_tilde = Wx[:, :, :self.n_out]
		f = self.sigmoid(Wx[:, :, self.n_out:2*self.n_out] + self.bf)
		r = self.sigmoid(Wx[:, :, 2*self.n_out:] + self.br)
		ct = Variable(x.data.new(self.bs, self.n_out).zero_())
		H_fwd, H_bwd = [], []
		for i in range(self.n_steps):
			ft = f[:, i]
			rt = r[:, i]
			xt_tilde = x_tilde[:, i]
			ht, ct = SRUStep(ft, rt, xt_tilde, ct)
			H_fwd.append(ht)
		ct = Variable(x.data.new(self.bs, self.n_out).zero_())
		for i in range(self.n_steps-1, -1, -1):
			ft = f[:, i]
			rt = r[:, i]
			xt_tilde = x_tilde[:, i]
			ht, ct = SRUStep(ft, rt, xt_tilde, ct)
			H_bwd.append(ht)
		H_fwd = torch.stack(H_fwd).permute(1, 0, 2)
		# H_bwd was collected from the last step back to the first, so reverse it
		# before stacking to align each backward state with its time step.
		H_bwd = torch.stack(H_bwd[::-1]).permute(1, 0, 2)
		H = torch.cat([H_fwd, H_bwd], dim=2)
		return H

	def init_weight(self):
		"""
		"""
		val_range = (3.0 / self.n_in) ** 0.5
		# sample value from a uniform distribution
		self.W.data.uniform_(-val_range, val_range)
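
For completeness, here is a minimal sketch of how I sanity-check the module (the args values below are just placeholder hyper-parameters for illustration):

from types import SimpleNamespace
import torch
from torch.autograd import Variable

# Placeholder hyper-parameters; dim_w, dim_h, bs, sent_len are example values only.
args = SimpleNamespace(dim_w=100, dim_h=100, bs=4, sent_len=20)
sru = SRU(args)
# Random batch of word vectors with shape (bs, max_len, dim_x).
x = Variable(torch.randn(args.bs, args.sent_len, args.dim_w))
H = sru(x)
print(H.size())  # expect (4, 20, 200): forward and backward states concatenated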

I wonder whether my implementation is correct in PyTorch :thinking:

I also implemented SRU in PyTorch. I have two versions: one calls the authors' SRU code directly, and the other implements it from the formulas. Here is my demo: https://github.com/bamtercelboo/pytorch_SRU. I don't know whether your code is fully correct, and I'm not even sure about mine, so we can compare notes. I also have a question about CUDA: my formula-based SRU is not faster on the GPU than on the CPU, and I don't know the reason. Do you know why?
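In case it helps with the timing comparison, here is a rough sketch of how one might time a forward pass on GPU vs. CPU; torch.cuda.synchronize() is needed so the asynchronous CUDA kernels actually finish before the clock is read (the model and input names are placeholders):

import time
import torch

def time_forward(model, x, n_runs=20):
	"""Average forward-pass time in seconds; synchronizes CUDA if the input is on GPU."""
	if x.is_cuda:
		torch.cuda.synchronize()
	start = time.time()
	for _ in range(n_runs):
		model(x)
	if x.is_cuda:
		torch.cuda.synchronize()
	return (time.time() - start) / n_runs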