Implementing variational dropout causes NaN values

I’m trying to apply Variational Dropout to a Recurrent Highway Network (RHN).
Because my RHN implementation loops over the time-steps in a for loop, I need to store the mask and reuse it at every time-step. For each batch, I create a new mask.

This is my function for creating the mask:

import torch
import torch.nn as nn


class MaskDropout(nn.Module):
    """Same as Lockdropout, but for a single time-step - using the same mask"""
    def __init__(self, dropout=0.3):
        super(MaskDropout, self).__init__()
        self.mask = None
        self.dropout = dropout

    def reset(self, x):
        """ MUST BE CALLED BEFORE FORWARD! call this function every batch to reset the mask """
        # keep-mask: each entry is 1 with probability (1 - dropout), else 0
        m = x.new_empty(x.size()).bernoulli_(1 - self.dropout)
        # inverted-dropout scaling: kept entries become 1 / (1 - dropout)
        m = m.div_(1 - self.dropout)
        # a freshly created tensor does not require grad, so no Variable wrapper is needed
        self.mask = m

    def forward(self, x):
        """forward function for dropout implementation (mask is same)"""
        if not self.training or not self.dropout:
            return x
        return self.mask * x
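
As a quick sanity check of what such a mask looks like, here is a minimal sketch. `sample_mask` is a hypothetical helper that only mirrors the body of `reset` above; it is not part of the original code. It shows that kept entries equal 1 / (1 - dropout) and that reusing one mask gives identical masking at every time-step.

```python
import torch


def sample_mask(like, dropout=0.3):
    """Hypothetical helper: sample a scaled keep-mask shaped like `like`."""
    keep = 1.0 - dropout
    return like.new_empty(like.size()).bernoulli_(keep).div_(keep)


torch.manual_seed(0)
state = torch.ones(4, 8)          # [batch x hidden], all ones for readability
mask = sample_mask(state)         # sampled once per batch

# Reusing the same mask at three "time-steps" zeroes the same units each time.
outputs = [mask * state for _ in range(3)]

kept = mask[mask != 0]            # surviving entries, each scaled by 1 / 0.7
```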

This is an example of feedforward with a mask:

    def forward(self, input, s_t0_l0):
        _st_batch_size = s_t0_l0.size(0)
        _st_hid_dim = s_t0_l0.size(1)

        _timesteps = input.size(0)
        _batch_size = input.size(1)
        _features = input.size(2)

        inputs = list(input.unbind(0))  # splitting inputs into time-steps (list of ts elements of [bs x feats])
        s_t_minus_1 = s_t0_l0
        outputs = []  # per-time-step outputs, stacked after the loop

        self.maskdrop.reset(s_t0_l0)                # resetting the variational dropout masks

        for t in range(len(inputs)):
            # s_t_L = output of RHN only
            s_t_L = self.rhncell(inputs[t], s_t_minus_1)  # RHN Cell
            # s_t = output of HSG (after RHN) - This is an extension, you can skip this line..
            s_t = self.hsgcell(s_t_L, s_t_minus_1)  # HSG Cell

            # variational dropout for s_t
            s_t = self.maskdrop(s_t)     

            s_t_minus_1 = s_t
            outputs += [s_t]

        rhn_outputs = torch.stack(outputs)
        dense_out = self.output_fc(rhn_outputs)
        ...

The problem is the line m = m.div_(1 - self.dropout): with it, NaN values appear at some point. Without it everything works fine, but then it is not a correct implementation of dropout…

My theory is that repeatedly multiplying the same values by a factor greater than 1 causes exploding gradients, but this has been done before (as you can see from implementations of variational dropout for LSTM networks…)
What am I missing here?
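
To put rough numbers on that theory: if the state of a kept unit passed from one time-step to the next essentially unchanged apart from the mask (a worst-case assumption, not something verified for this network), the factor 1 / (1 - dropout) would compound geometrically. A minimal sketch in plain Python, where `compounded_scale` is a hypothetical helper:

```python
def compounded_scale(dropout: float, timesteps: int) -> float:
    """Factor applied to a kept unit after `timesteps` multiplications by the same mask."""
    keep = 1.0 - dropout
    return (1.0 / keep) ** timesteps


# Growth of the scale factor for dropout = 0.3 over increasingly long sequences.
for t in (1, 10, 50, 100):
    print(t, compounded_scale(0.3, t))
```

In floating point, once a value overflows to inf, multiplying it by a zero entry of the mask gives nan, which would be one way for NaN values to show up.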

Thank you!

The problem was that I was applying dropout to the output of a linear function that comes after the activation function.
According to the Variational Dropout paper I mentioned before, the dropout mask should be applied before the non-linear activation (sigmoid / tanh in my case…)
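
A minimal sketch of that placement, under stated assumptions: `ToyRecurrentCell` is a hypothetical stand-in for the real RHN cell (a single tanh layer, no highway gates), and `reset_mask` mirrors the mask sampling from the question. The point is only that the mask is applied to the recurrent state before the linear layer and the tanh, so the bounded activation limits the scaled values at every time-step.

```python
import torch
import torch.nn as nn


class ToyRecurrentCell(nn.Module):
    """Hypothetical cell, used only to illustrate where the mask is applied."""

    def __init__(self, in_dim, hid_dim, dropout=0.3):
        super().__init__()
        self.in_fc = nn.Linear(in_dim, hid_dim)
        self.rec_fc = nn.Linear(hid_dim, hid_dim, bias=False)
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, state):
        """Sample one scaled keep-mask per batch; reused at every time-step."""
        keep = 1.0 - self.dropout
        self.mask = state.new_empty(state.size()).bernoulli_(keep).div_(keep)

    def forward(self, x, state):
        if self.training and self.dropout:
            # mask the recurrent state BEFORE the linear layer and non-linearity
            state = state * self.mask
        # tanh keeps the new state in [-1, 1], so the scaling cannot compound
        return torch.tanh(self.in_fc(x) + self.rec_fc(state))


torch.manual_seed(0)
cell = ToyRecurrentCell(in_dim=5, hid_dim=8)
inputs = torch.randn(200, 4, 5)       # [timesteps x batch x features]
state = torch.zeros(4, 8)
cell.reset_mask(state)                # once per batch
for x_t in inputs.unbind(0):
    state = cell(x_t, state)
```

With this ordering the state fed to the next time-step is always a tanh output, so it stays bounded even over long sequences.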