I’m trying to implement Variational Dropout to Recurrent Highway Network.
Because my implementation of the RHN is in for loop for timesteps, I need to save the mask for every time-step for future use. In each batch, I’m creating a new mask.
This is my function for creating the mask:
class MaskDropout(nn.Module):
"""Same as Lockdropout, but for a single time-step - using the same mask"""
def __init__(self, dropout=0.3):
super(MaskDropout, self).__init__()
self.mask = None
self.dropout = dropout
def reset(self, x):
""" MUST BE CALLED BEFORE FORWARD! call this function every batch to reset the mask """
m = x.data.new(x.size()).bernoulli_(1 - self.dropout)
m = m.div_(1 - self.dropout)
self.mask = torch.autograd.Variable(m, requires_grad=False)
def forward(self, x):
"""forward function for dropout implementation (mask is same)"""
if not self.training or not self.dropout:
return x
return self.mask * x
This is an example of feedforward with a mask:
def forward(self, input, s_t0_l0):
_st_batch_size = s_t0_l0.size(0)
_st_hid_dim = s_t0_l0.size(1)
_timesteps = input.size(0)
_batch_size = input.size(1)
_features = input.size(2)
inputs = list(input.unbind(0)) # splitting inputs into time-steps (list of ts elements of [bs x feats])
s_t_minus_1 = s_t0_l0
self.maskdrop.reset(s_t0_l0) # resetting the variational dropout masks
for t in range(len(inputs)):
# s_t_L = output of RHN only
s_t_L = self.rhncell(inputs[t], s_t_minus_1) # RHN Cell
# s_t = output of HSG (after RHN) - This is an extension, you can skip this line..
s_t = self.hsgcell(s_t_L, s_t_minus_1) # HSG Cell
# variational dropout for s_t
s_t = self.maskdrop(s_t)
s_t_minus_1 = s_t
outputs += [s_t]
rhn_outputs = torch.stack(outputs)
dense_out = self.output_fc(rhn_outputs)
...
Now the problem is that with the following line, m = m.div_(1 - self.dropout)
, at some point, there are nan values. Without it everything is perfect, but it’s a wrong implementation of dropout in general…
My theory is that multiplying each time the same values by a value greater than 1 will cause gradients explosion, but this has been done in the past (as you can see from implementing variational dropout for LSTM networks…)
What am I missing here?
Thank you!