Hello,
I am trying to implement a Simple Recurrent Unit (SRU). The core idea of SRU lies in Equations (3)–(7) of the paper, and my naive implementation (i.e., without any optimization) of a bidirectional SRU is below:
def SRUStep(ft, rt, xt_tilde, ctm1):
    """Apply a single SRU time step.

    Args:
        ft: forget-gate values for this step.
        rt: reset-gate values for this step.
        xt_tilde: projected input for this step.
        ctm1: cell state carried over from the previous step.

    Returns:
        Tuple ``(h_t, c_t)``: the output for this step and the updated
        cell state.

    All four arguments are combined element-wise, so they presumably share
    one shape (or broadcast against each other) -- confirm against caller.
    """
    # New cell state: gate-weighted blend of the previous state and the
    # projected input.
    carried = ft * ctm1
    written = (1.0 - ft) * xt_tilde
    c_t = carried + written

    # Output: the reset gate mixes the squashed cell state with the
    # projected input.
    h_t = rt * c_t.tanh() + (1.0 - rt) * xt_tilde
    return h_t, c_t
class SRU(nn.Module):
    """Naive (unoptimised) bidirectional Simple Recurrent Unit layer.

    A single projection matrix ``W`` produces, for every time step, the
    projected input and the pre-activations of the forget and reset gates.
    The recurrence is then run once left-to-right and once right-to-left
    using the same projections and parameters, and the two output
    sequences are concatenated along the feature axis.
    """

    def __init__(self, args):
        """Build the layer.

        Args:
            args: configuration object providing
                ``dim_w`` (input feature size),
                ``dim_h`` (hidden size per direction),
                ``bs`` and ``sent_len`` (stored as ``self.bs`` and
                ``self.n_steps`` for backward compatibility only;
                ``forward`` reads the real sizes from its input).
        """
        super(SRU, self).__init__()
        self.n_in = args.dim_w
        self.n_out = args.dim_h
        # One matrix holds the three projections side by side:
        # [x_tilde | forget gate | reset gate].
        self.W = Parameter(torch.zeros(self.n_in, 3 * self.n_out))
        self.bf = Parameter(torch.zeros(self.n_out))
        self.br = Parameter(torch.zeros(self.n_out))
        # dropout applied to the projected input / gate pre-activations
        self.Dropout_recurrent = nn.Dropout(0.5)
        # dropout applied to the raw input
        self.Dropout = nn.Dropout(0.3)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Kept so existing code reading these attributes keeps working.
        self.bs = args.bs
        self.n_steps = args.sent_len
        self.init_weight()

    def forward(self, x):
        """Run the bidirectional SRU over a batch of sequences.

        Args:
            x: input of shape (bs, max_len, dim_x), where dim_x must equal
                ``self.n_in``.

        Returns:
            Tensor of shape (bs, max_len, 2 * n_out). ``[..., :n_out]`` is
            the forward-direction output and ``[..., n_out:]`` the
            backward-direction output, both indexed by input position.
        """
        # Take the sizes from the input so partial batches and sequences of
        # any length work, instead of relying on constructor-time constants.
        n_steps = x.size(1)
        n_out = self.n_out

        x = self.Dropout(x)
        # shape: (bs, max_len, 3 * dim_h); matmul broadcasts W over the
        # batch, so W does not have to be copied once per batch element.
        Wx = torch.matmul(x, self.W)
        # using different dropout masks at different time steps
        Wx = self.Dropout_recurrent(Wx)

        x_tilde = Wx[:, :, :n_out]
        f = self.sigmoid(Wx[:, :, n_out:2 * n_out] + self.bf)
        r = self.sigmoid(Wx[:, :, 2 * n_out:] + self.br)

        H_fwd = self._scan(x, x_tilde, f, r, range(n_steps))
        H_bwd = self._scan(x, x_tilde, f, r, range(n_steps - 1, -1, -1))
        # The backward scan visits positions last-to-first; flip its outputs
        # back into input order so that position t of both directions lines
        # up before concatenation.
        H_bwd.reverse()

        H_fwd = torch.stack(H_fwd, dim=1)
        H_bwd = torch.stack(H_bwd, dim=1)
        H = torch.cat([H_fwd, H_bwd], dim=2)
        return H

    def _scan(self, x, x_tilde, f, r, steps):
        """Run the SRU recurrence over ``steps``; return outputs in visiting order."""
        # Zero initial cell state, created from x so it matches x's tensor
        # type and batch size.
        ct = Variable(x.data.new(x.size(0), self.n_out).zero_())
        outputs = []
        for i in steps:
            ht, ct = SRUStep(f[:, i], r[:, i], x_tilde[:, i], ct)
            outputs.append(ht)
        return outputs

    def init_weight(self):
        """Initialise ``W`` uniformly in [-sqrt(3 / n_in), sqrt(3 / n_in)].

        The gate biases ``bf`` and ``br`` are left at zero.
        """
        val_range = (3.0 / self.n_in) ** 0.5
        # sample value from a uniform distribution
        self.W.data.uniform_(-val_range, val_range)
I wonder whether my implementation is correct in PyTorch.