Implementing recurrent dropout

I am looking for a quick and easy way to implement recurrent dropout (Gal and Ghahramani, 2016) in PyTorch. For now I have written a custom LSTM cell myself. It looks like this:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter, init
from torch.nn.modules.rnn import RNNCellBase


class LSTMCell(RNNCellBase):

    def __init__(self, input_size, hidden_size, dropout=None):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        self.W_i = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_i = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_i = Parameter(torch.Tensor(hidden_size))

        self.W_f = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_f = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_f = Parameter(torch.Tensor(hidden_size))

        self.W_c = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_c = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_c = Parameter(torch.Tensor(hidden_size))

        self.W_o = Parameter(torch.Tensor(hidden_size, input_size))
        self.U_o = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_o = Parameter(torch.Tensor(hidden_size))

        self._input_dropout_mask = self._h_dropout_mask = None

        self.reset_parameters()

    def reset_parameters(self):
        init.orthogonal(self.W_i)
        init.orthogonal(self.U_i)
        init.orthogonal(self.W_f)
        init.orthogonal(self.U_f)
        init.orthogonal(self.W_c)
        init.orthogonal(self.U_c)
        init.orthogonal(self.W_o)
        init.orthogonal(self.U_o)
        self.b_i.data.fill_(0.)
        self.b_f.data.fill_(1.)  # forget-gate bias initialized to 1
        self.b_c.data.fill_(0.)
        self.b_o.data.fill_(0.)

    def set_dropout_masks(self, batch_size):
        # Sample one Bernoulli mask per gate (four for the input, four for the
        # hidden state) and keep them fixed for the whole sequence, as in
        # Gal and Ghahramani (2016).
        if self.dropout:
            if self.training:
                self._input_dropout_mask = Variable(torch.bernoulli(
                    torch.Tensor(4, batch_size, self.input_size).fill_(1 - self.dropout)), requires_grad=False)
                self._h_dropout_mask = Variable(torch.bernoulli(
                    torch.Tensor(4, batch_size, self.hidden_size).fill_(1 - self.dropout)), requires_grad=False)

                if torch.cuda.is_available():
                    self._input_dropout_mask = self._input_dropout_mask.cuda()
                    self._h_dropout_mask = self._h_dropout_mask.cuda()
            else:
                self._input_dropout_mask = self._h_dropout_mask = [1. - self.dropout] * 4
        else:
            self._input_dropout_mask = self._h_dropout_mask = [1.] * 4

    def forward(self, input, hidden_state):
        h_tm1, c_tm1 = hidden_state

        # Masks are sampled lazily on the first step; call set_dropout_masks()
        # again at the start of each new sequence to re-sample them.
        if self._input_dropout_mask is None:
            self.set_dropout_masks(input.size(0))

        xi_t = F.linear(input * self._input_dropout_mask[0], self.W_i, self.b_i)
        xf_t = F.linear(input * self._input_dropout_mask[1], self.W_f, self.b_f)
        xc_t = F.linear(input * self._input_dropout_mask[2], self.W_c, self.b_c)
        xo_t = F.linear(input * self._input_dropout_mask[3], self.W_o, self.b_o)

        i_t = F.sigmoid(xi_t + F.linear(h_tm1 * self._h_dropout_mask[0], self.U_i))
        f_t = F.sigmoid(xf_t + F.linear(h_tm1 * self._h_dropout_mask[1], self.U_f))
        c_t = f_t * c_tm1 + i_t * F.tanh(xc_t + F.linear(h_tm1 * self._h_dropout_mask[2], self.U_c))
        o_t = F.sigmoid(xo_t + F.linear(h_tm1 * self._h_dropout_mask[3], self.U_o))
        h_t = o_t * F.tanh(c_t)

        return h_t, c_t

However, it’s about two times slower than the built-in LSTMCell. How can I improve its efficiency? Or is there any way to reuse the built-in LSTMCell to implement it? Thanks in advance!
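For context, the kind of workaround I was hoping for would look roughly like the sketch below: keep the stock nn.LSTMCell and multiply its inputs by fixed per-sequence Bernoulli masks. Note that this shares a single mask across all four gates rather than using per-gate masks as in the code above, so it is only an approximation of the paper, and the VariationalLSTMCell name here is made up for illustration.

import torch
import torch.nn as nn

class VariationalLSTMCell(nn.Module):
    # Hypothetical wrapper around the built-in nn.LSTMCell: the same (inverted)
    # dropout mask is applied to the input and to h at every time step.
    def __init__(self, input_size, hidden_size, dropout=0.25):
        super(VariationalLSTMCell, self).__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.dropout = dropout
        self._x_mask = self._h_mask = None

    def sample_masks(self, batch_size):
        # Call once at the start of every sequence.
        keep = 1.0 - self.dropout
        w = next(self.cell.parameters())
        self._x_mask = torch.bernoulli(
            w.new_full((batch_size, self.cell.input_size), keep)) / keep
        self._h_mask = torch.bernoulli(
            w.new_full((batch_size, self.cell.hidden_size), keep)) / keep

    def forward(self, x, state):
        h, c = state
        if self.training and self.dropout > 0:
            x = x * self._x_mask
            h = h * self._h_mask
        return self.cell(x, (h, c))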


Are there any plans to add this to PyTorch? @apaszke


You can make it somewhat faster by combining the matrix multiplies. Here is an LSTM cell I am using (based on some code in PyTorch) that you can use as a model. Note that it expects to get its configuration from an object cfg rather than through arguments. Also, it implements zoneout and batch norm, but not recurrent dropout. Naive batch norm (i.e., not “recurrent” batch norm) is reported in the literature to work poorly because it assumes translation invariance across the sequence, but I have found that it helps when training on long sequences.

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import init

# `cfg` is an external configuration object, as noted above; it is not defined here.


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size,
                 norms=cfg.lstm.norms,
                 tie_forget=cfg.lstm.tie_forget,
                 forget_bias=cfg.lstm.forget_bias,
                 activation_function=cfg.lstm.activation_function):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norms = norms
        self.tie_forget = tie_forget
        self.forget_bias = forget_bias
        self.af = activation_function

        self.matrix_width = 3 if self.tie_forget else 4

        self.combined_weights = nn.Parameter(
            torch.FloatTensor(self.input_size + self.hidden_size, self.matrix_width * self.hidden_size))

        if 'batch' in self.norms:
            self.bn = nn.BatchNorm1d(self.matrix_width * self.hidden_size)
            self.bn_c = nn.BatchNorm1d(self.hidden_size)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(self.matrix_width * self.hidden_size))

        # This seems like a hacky way to implement zoneout, but I'm not sure what the correct way would be
        self.register_buffer('bernoulli_mask',
                             torch.Tensor(1).fill_(cfg.lstm.zoneout).expand((self.hidden_size,)))

        self.reset_parameters()

    def reset_parameters(self):

        # Initialize combine_weights
        weight_ih_data = init.orthogonal(torch.Tensor(self.input_size, self.matrix_width * self.hidden_size))
        weight_hh_data = torch.eye(self.hidden_size).repeat(1, self.matrix_width)
        self.combined_weights.data.set_(torch.cat((weight_hh_data, weight_ih_data), 0))

        if 'batch' in self.norms:
            self.bn.reset_parameters()
            self.bn_c.reset_parameters()
            self.bn.bias.data[0:self.hidden_size].fill_(self.forget_bias)
        else:
            self.bias.data.fill_(0)
            self.bias.data[0:self.hidden_size].fill_(self.forget_bias)


    def forward(self, input, hx):
        """
        Args:
            input: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        combined_inputs = torch.cat((h_0, input), 1)

        if 'batch' in self.norms:
            preactivations = torch.mm(combined_inputs, self.combined_weights)
            preactivations = self.bn(preactivations)
        else:
            batch_size = h_0.size(0)
            bias_batch = self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())
            preactivations = torch.addmm(bias_batch, combined_inputs, self.combined_weights)

        if self.tie_forget:
            fi, o, g   = torch.split(preactivations, split_size=self.hidden_size, dim=1)
            c_1 = torch.sigmoid(fi)*c_0 + torch.sigmoid(1-fi)*self.af(g)
        else:
            f, i, o, g = torch.split(preactivations, split_size=self.hidden_size, dim=1)
            c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*self.af(g)

        h_1 = torch.sigmoid(o) * self.af(self.bn_c(c_1) if 'batch' in self.norms else c_1)

        if cfg.lstm.zoneout > 0:
            if cfg.training:
                h_mask = Variable(torch.bernoulli(self.bernoulli_mask))
                c_mask = Variable(torch.bernoulli(self.bernoulli_mask))
                h_1 = h_0 * h_mask + h_1 * (1-h_mask)
                c_1 = c_0 * c_mask + c_1 * (1-c_mask)
            else:
                h_1 = h_0 * cfg.lstm.zoneout + h_1 * (1-cfg.lstm.zoneout)
                c_1 = c_0 * cfg.lstm.zoneout + c_1 * (1-cfg.lstm.zoneout)

        return h_1, c_1
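
For reference, a cell like either of the ones above is typically unrolled over a sequence with a plain Python loop. Here is a minimal driver sketch, assuming zero initial states and an inputs tensor of shape (seq_len, batch, input_size); for the first cell, set_dropout_masks(batch_size) would be called before the loop.

import torch

def run_cell(cell, inputs, hidden_size):
    # Unroll a step-wise cell over a (seq_len, batch, input_size) tensor,
    # starting from zero states, and return the stacked hidden states
    # together with the final (h, c) pair.
    seq_len, batch_size, _ = inputs.size()
    h = inputs.new_zeros(batch_size, hidden_size)
    c = inputs.new_zeros(batch_size, hidden_size)
    outputs = []
    for t in range(seq_len):
        h, c = cell(inputs[t], (h, c))
        outputs.append(h)
    return torch.stack(outputs, 0), (h, c)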

Did you get a faster version?

I think they mentioned that the built-in implementation is a wrapper around the CUDA backend, and recurrent dropout is not directly supported there.