I am looking for a quick and easy way to implement recurrent dropout (Gal and Ghahramani, 2016) in PyTorch. For now I have simply written a custom LSTM cell myself. It looks like:
class LSTMCell(RNNCellBase):
    """LSTM cell with variational (recurrent) dropout, after Gal & Ghahramani (2016).

    A single Bernoulli dropout mask per gate is sampled and then reused at
    every time step until ``set_dropout_masks`` is called again — that reuse
    across steps is what makes the dropout "recurrent".  Callers should call
    ``set_dropout_masks(batch_size)`` once per sequence/batch; otherwise the
    masks sampled lazily on the first ``forward`` are kept forever.

    NOTE(review): this implements *non-inverted* dropout — at train time
    activations are multiplied by a 0/1 mask with keep probability
    ``1 - dropout``; at eval time they are rescaled by ``1 - dropout``.
    """

    def __init__(self, input_size, hidden_size, dropout=None):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Drop probability in [0, 1); None or 0 disables dropout entirely.
        self.dropout = dropout

        # One (W, U, b) triple per gate: i = input, f = forget,
        # c = candidate cell, o = output.  W_* project the input,
        # U_* project the previous hidden state.
        self.W_i = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_i = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_i = Parameter(torch.Tensor(hidden_size))

        self.W_f = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_f = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_f = Parameter(torch.Tensor(hidden_size))

        self.W_c = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_c = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_c = Parameter(torch.Tensor(hidden_size))

        self.W_o = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_o = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_o = Parameter(torch.Tensor(hidden_size))

        # None means "not sampled yet"; forward() samples lazily.
        self._input_dropout_mask = self._h_dropout_mask = None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize every parameter.

        Fixes the original, which never initialized W_c, U_c, b_i, b_c and
        b_o — they kept whatever uninitialized memory ``torch.Tensor(...)``
        happened to contain.
        """
        for weight in (self.W_i, self.U_i, self.W_f, self.U_f,
                       self.W_c, self.U_c, self.W_o, self.U_o):
            init.orthogonal_(weight)
        for bias in (self.b_i, self.b_c, self.b_o):
            bias.data.fill_(0.)
        # Forget-gate bias of 1 biases the cell toward remembering early in
        # training (standard LSTM initialization trick).
        self.b_f.data.fill_(1.)

    def set_dropout_masks(self, batch_size):
        """Sample one dropout mask per gate for this batch size.

        Shapes: (4, batch_size, input_size) for the input masks and
        (4, batch_size, hidden_size) for the recurrent masks — one slice
        per gate, reused at every time step of the sequence.
        """
        if self.dropout:
            if self.training:
                keep_prob = 1. - self.dropout
                self._input_dropout_mask = Variable(torch.bernoulli(
                    torch.Tensor(4, batch_size, self.input_size).fill_(keep_prob)),
                    requires_grad=False)
                self._h_dropout_mask = Variable(torch.bernoulli(
                    torch.Tensor(4, batch_size, self.hidden_size).fill_(keep_prob)),
                    requires_grad=False)
                # Move masks to the GPU only when the module itself lives
                # there.  The original moved them whenever CUDA was merely
                # *available*, breaking CPU runs on CUDA-capable machines.
                if self.W_i.is_cuda:
                    self._input_dropout_mask = self._input_dropout_mask.cuda()
                    self._h_dropout_mask = self._h_dropout_mask.cuda()
            else:
                # Eval mode: rescale activations by the keep probability
                # instead of masking (non-inverted dropout convention).
                self._input_dropout_mask = self._h_dropout_mask = [1. - self.dropout] * 4
        else:
            # Dropout disabled: identity masks.
            self._input_dropout_mask = self._h_dropout_mask = [1.] * 4

    def forward(self, input, hidden_state):
        """Run one LSTM step.

        Args:
            input: (batch, input_size) tensor for the current time step.
            hidden_state: tuple (h_{t-1}, c_{t-1}), each (batch, hidden_size).

        Returns:
            Tuple (h_t, c_t) of the new hidden and cell states.
        """
        h_tm1, c_tm1 = hidden_state

        # Lazily sample masks on first use; they are then reused until the
        # caller resets them via set_dropout_masks().
        # NOTE(review): masks are NOT resampled per batch automatically, and
        # a later batch of a different size will shape-mismatch — confirm the
        # caller invokes set_dropout_masks() for each new batch/sequence.
        if self._input_dropout_mask is None:
            self.set_dropout_masks(input.size(0))

        # Input contributions, one per gate, each with its own mask slice.
        xi_t = F.linear(input * self._input_dropout_mask[0], self.W_i, self.b_i)
        xf_t = F.linear(input * self._input_dropout_mask[1], self.W_f, self.b_f)
        xc_t = F.linear(input * self._input_dropout_mask[2], self.W_c, self.b_c)
        xo_t = F.linear(input * self._input_dropout_mask[3], self.W_o, self.b_o)

        # Gate activations; the recurrent term gets its own dropout mask.
        i_t = torch.sigmoid(xi_t + F.linear(h_tm1 * self._h_dropout_mask[0], self.U_i))
        f_t = torch.sigmoid(xf_t + F.linear(h_tm1 * self._h_dropout_mask[1], self.U_f))
        c_t = f_t * c_tm1 + i_t * torch.tanh(xc_t + F.linear(h_tm1 * self._h_dropout_mask[2], self.U_c))
        o_t = torch.sigmoid(xo_t + F.linear(h_tm1 * self._h_dropout_mask[3], self.U_o))
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t
However, it is about 2 times slower than the built-in LSTMCell. How can I improve its efficiency? Or is there any way to reuse the built-in LSTMCell to implement it? Thanks in advance!