Dropout for RNNs

In Torch7, the Dropout module in the rnn library, https://github.com/Element-Research/rnn/blob/master/Dropout.lua, lets a sequence use the same dropout mask at every time step, so that masking is consistent across the sequence.

Is there an elegant way to use the same dropout mask across a sequence for RNNs, or would it be better to implement a custom module?

(The dropout option in the current RNN modules just treats the entire sequence output as a single output, right?)

Thanks!

You could write your own module, where you process the whole sequence, sample the mask once at the beginning and just do an element-wise multiplication after each step. Here’s a gist:

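import torch
from torch.autograd import Variable

def forward(self, x):
    # x: (batch, seq_len, features). Sample one mask and reuse it at every step.
    # Caveat (raised later in this thread): in-place bernoulli_() samples with its
    # own p argument (default 0.5) and ignores the value written by fill_().
    mask = Variable(torch.Tensor(x.size(0), self.hidden_size).fill_(self.dropout).bernoulli_())
    hidden = None
    outputs = []
    for t in range(x.size(1)):  # iterate over time steps
        output, hidden = self.rnn(x[:, t], hidden)
        outputs.append(output * mask)
    return outputs, hidden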

Regarding the current RNN modules, they use a different mask for each time step.

Performance-wise, running the RNN on the whole input sequence, then expanding the mask and applying it to the whole RNN output, will probably be faster than looping over time steps.
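
A minimal sketch of that suggestion, assuming a batch-first nn.LSTM and current PyTorch tensors rather than Variable. The helper name apply_shared_mask and the sizes are made up for illustration, and p is treated here as the drop probability:

```python
import torch
import torch.nn as nn

def apply_shared_mask(output, p, training=True):
    """Drop the same units at every time step of a batch-first RNN output.

    output: (batch, seq_len, hidden_size); p: probability of dropping a unit.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return output
    keep = 1.0 - p
    # One mask per sequence; the size-1 time dimension broadcasts over all steps.
    mask = torch.bernoulli(output.new_full((output.size(0), 1, output.size(2)), keep))
    # Inverted-dropout scaling keeps the expected activation unchanged.
    return output * mask / keep

rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)            # (batch, seq_len, features)
output, hidden = rnn(x)              # one call over the whole sequence, no Python loop
dropped = apply_shared_mask(output, p=0.5)
```

Note that this only masks the RNN output. It does not touch the recurrent connections inside the LSTM, which would still need a per-step loop or a custom cell.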

I am looking for a way of applying dropout in between layers for a stacked LSTM.
I have code that implements this with the cell version of the LSTM, running it on each step of the sequence, and I can attest to a slowdown by a factor of 5 (from 40,000 tokens/s to 8,000 tokens/s).

@emanjavacas do you want to use a single mask for all timesteps, or a separate mask per timestep?

Sorry, I was referring to the previous example, which works on a per-timestep basis. For the other case, I believe using LSTM(..., dropout=dropout) should be enough?

Yes, if a separate mask for each timestep is OK, then you can just use the built-in modules. They will be the fastest.
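
For illustration, a small sketch of the built-in route with a stacked nn.LSTM. The sizes and the dropout value are arbitrary; the dropout argument applies dropout to the outputs of every layer except the last one, and only in training mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two stacked layers; dropout acts on the output of layer 1 before it feeds layer 2.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, dropout=0.3)
x = torch.randn(10, 4, 8)            # (seq_len, batch, features)

lstm.train()                         # dropout active
out_train, _ = lstm(x)

lstm.eval()                          # dropout disabled, output is deterministic
out_eval, (h_n, c_n) = lstm(x)
```

With num_layers=1 the dropout argument has nothing to act on, so it only matters for stacked models.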

I found that using the same mask for every time step is also simple: inherit from torch.nn._functions.dropout.Dropout and override _make_noise as follows (assuming the input is seqlen x batchsize x dim):

def _make_noise(self, input):
    return input.new().resize_(1, input.size(1), input.size(2))

@supakjk not sure how you use that module, but note that depending on undocumented methods (especially those prefixed with an underscore) isn’t recommended, as they can change without notice.

What I did to use the same dropout mask for different time steps was inheriting classes as follows:

import torch
import torch.nn as nn

class SeqConstDropoutFunc(torch.nn._functions.dropout.Dropout):
    def __init__(self, p=0.5, train=False, inplace=False):
        super(SeqConstDropoutFunc, self).__init__(p, train, inplace)

    def _make_noise(self, input):
        # For (timesteps x batches x dims) inputs, the size-1 first dimension
        # makes every time step share the same dropout mask.
        return input.new().resize_(1, input.size(1), input.size(2))

class SeqConstDropout(nn.Dropout):
    def __init__(self, p=0.5, inplace=False):
        super(SeqConstDropout, self).__init__(p, inplace)

    def forward(self, input):
        return SeqConstDropoutFunc(self.p, self.training, self.inplace)(input)

It seems that overriding _make_noise isn’t a good idea. I’ll either keep track of changes to that code or write an independent dropout class.

Thanks.

Yes I think it’s best to reimplement it yourself. It’s a very simple function.
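
A rough sketch of what such an independent class could look like, assuming (timesteps x batches x dims) inputs as above and current PyTorch. The class name SharedMaskDropout is hypothetical and it relies on no private PyTorch internals:

```python
import torch
import torch.nn as nn

class SharedMaskDropout(nn.Module):
    """Hypothetical standalone module: one dropout mask per sequence,
    reused at every time step. Expects input shaped (seq_len, batch, dim)."""

    def __init__(self, p=0.5):
        super().__init__()
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p  # probability of dropping a unit

    def forward(self, input):
        if not self.training or self.p == 0.0:
            return input
        keep = 1.0 - self.p
        # Size-1 time dimension: the mask broadcasts across all time steps.
        mask = torch.bernoulli(input.new_full((1, input.size(1), input.size(2)), keep))
        # Rescale so the expected value of each unit is unchanged.
        return input * mask / keep

drop = SharedMaskDropout(p=0.5)
drop.train()
masked = drop(torch.ones(5, 3, 4))   # same zero pattern at each of the 5 steps
```

Calling drop.eval() turns the module into an identity, which matches how nn.Dropout behaves.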

I think you should sample the bernoulli using:

Variable(torch.bernoulli(torch.Tensor(x.size(0), self.hidden_size).fill_(self.dropout)))

otherwise your mask won’t be correct when using self.dropout = 1.0 or self.dropout = 0.0. I just ran into this, and found that there is an issue about it on GitHub.

Is there an elegant implementation of it already? I guess the idea of using the same dropout mask is from “A Theoretically Grounded Application of Dropout in Recurrent Neural Networks”.

Did you get it implemented in PyTorch already?

Variable(torch.Tensor(x.size(0), self.hidden_size).fill_(self.dropout).bernoulli_()) is not correct.

The correct form is Variable(torch.Tensor(x.size(0), self.hidden_size).fill_(self.dropout).bernoulli())
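
To make the difference concrete, here is a small check written against current PyTorch (no Variable). It assumes the documented behavior that the in-place bernoulli_() takes its own probability argument p, defaulting to 0.5, while the out-of-place torch.bernoulli reads the probabilities from the tensor's entries:

```python
import torch

# Out-of-place: each entry of the input tensor is used as the success probability.
always_one = torch.bernoulli(torch.full((2, 3), 1.0))    # probability 1 everywhere
always_zero = torch.bernoulli(torch.full((2, 3), 0.0))   # probability 0 everywhere

# In-place with no argument: samples with p=0.5 and overwrites the filled values,
# so the 1.0 written by torch.full has no effect on the result.
in_place = torch.full((1000,), 1.0).bernoulli_()

# Passing the probability explicitly makes the in-place form behave as intended.
in_place_explicit = torch.empty(2, 3).bernoulli_(1.0)
```

This is why the edge cases 0.0 and 1.0 expose the problem: the out-of-place form gives all zeros or all ones, while the argument-free in-place form still yields a roughly even mix.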
