Naive implementation of SRU


#1

Hello,
I am trying to implement the Simple Recurrent Unit (SRU). The core idea of SRU lies in Equations (3)-(7) of the paper, and my naive implementation (i.e., without any optimization) of a bidirectional SRU is below.
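
For reference, the per-step recurrence my code computes ($\sigma$ is the sigmoid, $\odot$ is element-wise multiplication) is

$$\tilde{x}_t = W x_t, \qquad f_t = \sigma(W_f x_t + b_f), \qquad r_t = \sigma(W_r x_t + b_r)$$
$$c_t = f_t \odot c_{t-1} + (1 - f_t) \odot \tilde{x}_t$$
$$h_t = r_t \odot \tanh(c_t) + (1 - r_t) \odot \tilde{x}_t$$

where $W$, $W_f$ and $W_r$ are packed into a single (dim_w, 3 * dim_h) matrix so the projection can be done once for the whole sequence. Note that I use $\tilde{x}_t$ rather than $x_t$ in the highway term of $h_t$, matching my SRUStep below.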

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter

def SRUStep(ft, rt, xt_tilde, ctm1):
	"""
	One SRU step; all arguments have shape (bs, dim_h).
	"""
	c_t = ft * ctm1 + (1.0 - ft) * xt_tilde
	h_t = rt * torch.tanh(c_t) + (1.0 - rt) * xt_tilde
	return h_t, c_t

class SRU(nn.Module):
	"""
	implementation of a bidirectional Simple Recurrent Unit (SRU)
	"""
	def __init__(self, args):
		"""
		"""
		super(SRU, self).__init__()
		self.n_in = args.dim_w
		self.n_out = args.dim_h
		self.W = Parameter(torch.FloatTensor(np.zeros((self.n_in, 3 * self.n_out))))
		self.bf = Parameter(torch.FloatTensor(np.zeros(self.n_out)))
		self.br = Parameter(torch.FloatTensor(np.zeros(self.n_out)))

		# recurrent dropout
		self.Dropout_recurrent = nn.Dropout(0.5)
		# non-recurrent dropout
		self.Dropout = nn.Dropout(0.3)

		self.sigmoid = nn.Sigmoid()
		self.tanh = nn.Tanh()
		self.bs = args.bs
		self.n_steps = args.sent_len 
		self.init_weight()

	def forward(self, x):
		"""
		x shape: (bs, max_len, dim_x)
		"""
		x = self.Dropout(x)
		# shape: (bs, max_len, 3 * dim_h)
		Wx = torch.bmm(x, self.W.repeat(self.bs, 1, 1))
		# using different dropout masks at different time steps
		Wx = self.Dropout_recurrent(Wx)
		# split the projection into x_tilde, the forget gate f and the reset gate r
		x_tilde = Wx[:, :, :self.n_out]
		f = self.sigmoid(Wx[:, :, self.n_out:2*self.n_out] + self.bf)
		r = self.sigmoid(Wx[:, :, 2*self.n_out:] + self.br)
		# initial cell state for the forward direction
		ct = Variable(x.data.new(self.bs, self.n_out).zero_())
		H_fwd, H_bwd = [], []
		for i in range(self.n_steps):
			ft = f[:, i]
			rt = r[:, i]
			xt_tilde = x_tilde[:, i]
			ht, ct = SRUStep(ft, rt, xt_tilde, ct)
			H_fwd.append(ht)
		# reset the cell state for the backward direction
		ct = Variable(x.data.new(self.bs, self.n_out).zero_())
		for i in range(self.n_steps-1, -1, -1):
			ft = f[:, i]
			rt = r[:, i]
			xt_tilde = x_tilde[:, i]
			ht, ct = SRUStep(ft, rt, xt_tilde, ct)
			H_bwd.append(ht)
		H_fwd = torch.stack(H_fwd).permute(1, 0, 2)
		# H_bwd was collected from the last time step back to the first, so reverse it
		# before stacking to align the two directions at each time step
		H_bwd = torch.stack(H_bwd[::-1]).permute(1, 0, 2)
		H = torch.cat([H_fwd, H_bwd], dim=2)
		return H

	def init_weight(self):
		"""
		"""
		val_range = (3.0 / self.n_in) ** 0.5
		# sample values from a uniform distribution
		self.W.data.uniform_(-val_range, val_range)

I wonder whether my implementation is correct in PyTorch :thinking:
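
For completeness, this is the kind of quick shape check I run. The SimpleNamespace args here is only a stand-in for my real argument object (it just needs dim_w, dim_h, bs and sent_len), and the sizes are made up:

from types import SimpleNamespace

args = SimpleNamespace(dim_w=100, dim_h=128, bs=4, sent_len=20)
sru = SRU(args)
x = Variable(torch.randn(args.bs, args.sent_len, args.dim_w))
H = sru(x)
print(H.size())  # expected: (4, 20, 256), forward and backward states concatenated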


(bamtercelboo) #2

I also implemented SRU in PyTorch. I have two versions: one calls the authors' SRU code directly, and the other implements it from the formulas. This is my demo: https://github.com/bamtercelboo/pytorch_SRU. I am not sure whether your code is correct, and I am not sure about mine either, so we can compare notes. I also have a question about CUDA: the formula-based SRU does not speed up compared to the CPU, and I don't know the reason. Do you know why?
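
For reference, this is roughly how I compare the GPU and CPU runs (just a sketch; time_forward, model and x are placeholder names, and the cuda flag only controls whether torch.cuda.synchronize() is called so the asynchronous GPU kernels have actually finished before reading the clock):

import time
import torch

def time_forward(model, x, n_runs=100, cuda=False):
	# one warm-up pass (the first CUDA call also pays a one-time initialization cost)
	model(x)
	if cuda:
		torch.cuda.synchronize()
	start = time.time()
	for _ in range(n_runs):
		model(x)
	if cuda:
		# CUDA runs asynchronously, so wait for all queued kernels to finish
		torch.cuda.synchronize()
	return (time.time() - start) / n_runs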