We would like to have a PyTorch version of Concrete Dropout; the original Keras code is in the link. Could someone help decide whether it makes more sense to write it within _functions (like dropout, with both forward and backward) or as a class extending nn.Module (with forward only, wrapping the trainable p in a Variable())?
It looks like self.p is learnable, so you can implement it as an nn.Module (not in _functions; I don't think you need to implement your own autograd backward here, since it looks pretty simple).
Did you end up implementing this?
Hi!
I’ve attempted to implement Concrete Dropout in PyTorch following Yarin Gal’s repo. A notebook with an attempt to replicate the experiment can be found here on GitHub.
I suspect that I have missed something in the regularisation part of the model, hence poor replication of the original experiment (see link above).
If someone is willing to help out or offer advice on the implementation and on making the model work, I would be happy to learn!
At the moment, Concrete Dropout layer and the corresponding model are defined as such:
class ConcreteDropout(nn.Module):
    """Apply Concrete Dropout to the input of a wrapped layer.

    The dropout probability ``p = sigmoid(p_logit)`` is a learnable
    one-element parameter.  ``forward`` multiplies the input by a relaxed
    (continuous) dropout mask, rescales by ``1 / (1 - p)`` and feeds the
    result to ``layer``.  The mask is applied on every call, regardless of
    the module's train/eval mode.

    Args:
        layer: Module applied after the dropout mask.
        input_shape: Shape (or size) of one input sample; its product scales
            the dropout regulariser.
        wr: Weight-regularisation coefficient.
        dr: Dropout-regularisation coefficient.
        init_min: Lower bound of the initial dropout probability.
        init_max: Upper bound of the initial dropout probability.
    """

    def __init__(self, layer, input_shape, wr=1e-6, dr=1e-5, init_min=0.1, init_max=0.1):
        super(ConcreteDropout, self).__init__()
        # Post dropout layer
        self.layer = layer
        # Input dim (plain int so it mixes cleanly with tensors and floats)
        self.input_dim = int(np.prod(input_shape))
        # Regularisation hyper-params
        self.w_reg_param = wr
        self.d_reg_param = dr
        # Initialise p_logit uniformly between logit(init_min) and logit(init_max)
        init_min = float(np.log(init_min) - np.log(1. - init_min))
        init_max = float(np.log(init_max) - np.log(1. - init_max))
        self.p_logit = nn.Parameter(torch.empty(1))
        nn.init.uniform_(self.p_logit, a=init_min, b=init_max)

    def sum_of_square(self):
        """Return the sum of squared entries over all parameters of ``layer``.

        Used for parameter regularisation.  Returns the int ``0`` when the
        wrapped layer has no parameters.
        """
        return sum(torch.sum(torch.pow(param, 2)) for param in self.layer.parameters())

    def regularisation(self):
        """Return the regularisation term that should be added to the loss.

        The result is a one-element tensor (same shape as ``p_logit``).
        """
        # Derive p from the parameter itself rather than from the value cached
        # by forward(), so this can be called before any forward pass.
        p = torch.sigmoid(self.p_logit)
        weights_regularizer = self.w_reg_param * self.sum_of_square() / (1 - p)
        # Negative entropy of a Bernoulli(p) variable, scaled by the input size.
        dropout_regularizer = p * torch.log(p) + (1. - p) * torch.log(1. - p)
        dropout_regularizer = dropout_regularizer * (self.d_reg_param * self.input_dim)
        return weights_regularizer + dropout_regularizer

    def forward(self, x):
        """Apply the relaxed dropout mask to ``x`` and run the wrapped layer."""
        eps = 1e-7
        temp = 0.1
        # Kept as an attribute because callers may read the current probability.
        self.p = torch.sigmoid(self.p_logit)
        # One independent uniform sample per input element; a single scalar
        # would apply the same mask value to the whole batch.
        unif_noise = torch.rand_like(x)
        drop_prob = (torch.log(self.p + eps)
                     - torch.log(1 - self.p + eps)
                     + torch.log(unif_noise + eps)
                     - torch.log(1 - unif_noise + eps))
        drop_prob = torch.sigmoid(drop_prob / temp)
        random_tensor = 1 - drop_prob
        retain_prob = 1 - self.p
        # Out-of-place so the caller's tensor is never modified.
        x = torch.mul(x, random_tensor) / retain_prob
        return self.layer(x)
class Linear_relu(nn.Module):
    """A fully connected layer followed by a ReLU activation."""

    def __init__(self, inp, out):
        """Build the block mapping ``inp`` input features to ``out`` outputs."""
        super().__init__()
        stages = [nn.Linear(inp, out), nn.ReLU()]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(linear(x))``."""
        activated = self.model(x)
        return activated
class Model(nn.Module):
    """Heteroscedastic regression network built from Concrete Dropout layers.

    Three ``ConcreteDropout``-wrapped ``Linear_relu`` blocks form a shared
    trunk; two further ``ConcreteDropout``-wrapped linear heads predict the
    mean and the log-variance.  Layer widths come from the module-level
    ``nb_features`` and ``D``.

    Args:
        wr: Weight-regularisation coefficient passed to every dropout layer.
        dr: Dropout-regularisation coefficient passed to every dropout layer.
    """

    def __init__(self, wr, dr):
        super(Model, self).__init__()
        self.forward_main = nn.Sequential(
            ConcreteDropout(Linear_relu(1, nb_features), input_shape=1, wr=wr, dr=dr),
            ConcreteDropout(Linear_relu(nb_features, nb_features), input_shape=nb_features, wr=wr, dr=dr),
            ConcreteDropout(Linear_relu(nb_features, nb_features), input_shape=nb_features, wr=wr, dr=dr))
        # The output heads are plain linear layers: a ReLU here would force
        # mean >= 0 and log_var >= 0 (i.e. predicted variance >= 1).
        self.forward_mean = ConcreteDropout(nn.Linear(nb_features, D), input_shape=nb_features, wr=wr, dr=dr)
        self.forward_logvar = ConcreteDropout(nn.Linear(nb_features, D), input_shape=nb_features, wr=wr, dr=dr)

    def forward(self, x):
        """Return the ``(mean, log_var)`` predictions for ``x``."""
        x = self.forward_main(x)
        mean = self.forward_mean(x)
        log_var = self.forward_logvar(x)
        return mean, log_var

    def heteroscedastic_loss(self, true, mean, log_var):
        """Return ``sum(exp(-log_var) * (true - mean)**2 + log_var)``."""
        precision = torch.exp(-log_var)
        return torch.sum(precision * (true - mean)**2 + log_var)

    def regularisation_loss(self):
        """Return the summed regularisation terms of all dropout layers."""
        # Iterate instead of hard-coding trunk indices, so the sum stays
        # complete if the trunk depth changes.  Order matches the original.
        dropout_layers = list(self.forward_main) + [self.forward_mean, self.forward_logvar]
        reg_loss = dropout_layers[0].regularisation()
        for dropout_layer in dropout_layers[1:]:
            reg_loss = reg_loss + dropout_layer.regularisation()
        return reg_loss
Thank you!
Everything works fine now; the uniform noise was missing its size argument:
unif_noise = Variable(torch.FloatTensor(np.random.uniform(size=tuple(x.size()))))