Concrete Dropout implementation?

We would like to have a PyTorch version of Concrete Dropout; the original Keras code is in the link. Could someone take a look and advise whether it makes more sense to write it within _functions (like dropout, with both forward and backward) or as a class extending nn.Module (forward only, wrapping the trainable p as a Variable())?

It looks like self.p is learnable. You can implement it as an nn.Module (not _functions; I don't think you need to implement your own autograd backward here, it looks pretty simple).
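A minimal skeleton along those lines might look like the snippet below (just an illustrative sketch with a plain Bernoulli mask; the class name is made up, and Concrete Dropout would replace the hard mask with a relaxed, differentiable sample so that gradients can reach p):

import torch
import torch.nn as nn

class LearnableDropout(nn.Module):
    def __init__(self, init_p=0.1):
        super(LearnableDropout, self).__init__()
        # Store the dropout rate as a logit so the optimiser can move it freely
        self.p_logit = nn.Parameter(torch.log(torch.tensor(init_p / (1. - init_p))))

    def forward(self, x):
        p = torch.sigmoid(self.p_logit)  # trainable dropout probability
        # Hard Bernoulli mask: keep each unit with probability 1 - p.
        # This mask is not differentiable w.r.t. p, which is exactly why
        # Concrete Dropout uses a relaxed (Concrete) sample instead.
        mask = torch.rand_like(x).gt(p).float()
        return x * mask / (1. - p)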

Did you end up implementing this?

Hi!
I’ve attempted to implement Concrete Dropout in PyTorch following Yarin Gal’s repo. A notebook with an attempt to replicate the experiment can be found here on GitHub.

I suspect that I have missed something in the regularisation part of the model, hence the poor replication of the original experiment (see the link above).

If someone is willing to help out or offer advice on the implementation and on making the model work, I would be happy to learn!

At the moment, the Concrete Dropout layer and the corresponding model are defined as follows:

import numpy as np
import torch
import torch.nn as nn

class ConcreteDropout(nn.Module):
    
    def __init__(self, layer, input_shape, wr=1e-6, dr=1e-5, init_min=0.1, init_max=0.1):
        super(ConcreteDropout, self).__init__()
        
        # Post dropout layer
        self.layer = layer
        # Input dim
        self.input_dim = np.prod(input_shape)
        # Regularisation hyper-params
        self.w_reg_param = wr
        self.d_reg_param = dr
        
        # Initialise p_logit
        init_min = np.log(init_min) - np.log(1. - init_min)
        init_max = np.log(init_max) - np.log(1. - init_max)
        
        self.p_logit = nn.Parameter(torch.Tensor(1))
        nn.init.uniform(self.p_logit, a=init_min, b=init_max)
        
    def sum_of_square(self):
        """
        For parameter regularisation
        """
        sum_of_square = 0
        for param in self.layer.parameters():
            sum_of_square += torch.sum(torch.pow(param, 2))
        return sum_of_square
    
    def regularisation(self):
        """
        Returns regularisation term, should be added to the loss
        """
        # Weight term: sum of squared weights scaled by 1 / (1 - p)
        weights_regularizer = self.w_reg_param * self.sum_of_square() / (1 - self.p)
        # Dropout term: negative entropy of the Bernoulli, scaled by the input dimension
        dropout_regularizer = self.p * torch.log(self.p)
        dropout_regularizer += (1. - self.p) * torch.log(1. - self.p)
        dropout_regularizer *= self.d_reg_param * self.input_dim
        regularizer = weights_regularizer + dropout_regularizer
        return regularizer
    
    def forward(self, x):
        """
        Forward pass for dropout layer
        """
        eps = 1e-7
        # Temperature of the Concrete distribution
        temp = 0.1

        # Dropout probability from the trainable logit
        self.p = nn.functional.sigmoid(self.p_logit)

        unif_noise = np.random.uniform()

        # Relaxed (Concrete) Bernoulli sample used as a soft dropout mask
        drop_prob = (torch.log(self.p + eps)
                     - torch.log(1 - self.p + eps)
                     + np.log(unif_noise + eps)
                     - np.log(1 - unif_noise + eps))
        drop_prob = nn.functional.sigmoid(drop_prob / temp)
        random_tensor = 1 - drop_prob
        retain_prob = 1 - self.p

        x  = torch.mul(x, random_tensor)
        x /= retain_prob
        
        return self.layer(x)
    
class Linear_relu(nn.Module):
    
    def __init__(self, inp, out):
        super(Linear_relu, self).__init__()
        self.model = nn.Sequential(nn.Linear(inp, out), nn.ReLU())
        
    def forward(self, x):
        return self.model(x)


class Model(nn.Module):
    
    def __init__(self, wr, dr):
        super(Model, self).__init__()
        self.forward_main = nn.Sequential(
                  ConcreteDropout(Linear_relu(1, nb_features), input_shape=1, wr=wr, dr=dr),
                  ConcreteDropout(Linear_relu(nb_features, nb_features), input_shape=nb_features, wr=wr, dr=dr),
                  ConcreteDropout(Linear_relu(nb_features, nb_features), input_shape=nb_features, wr=wr, dr=dr))
        self.forward_mean = ConcreteDropout(Linear_relu(nb_features, D), input_shape=nb_features, wr=wr, dr=dr)
        self.forward_logvar = ConcreteDropout(Linear_relu(nb_features, D), input_shape=nb_features, wr=wr, dr=dr)
        
    def forward(self, x):
        x = self.forward_main(x)
        mean = self.forward_mean(x)
        log_var = self.forward_logvar(x)
        return mean, log_var

    def heteroscedastic_loss(self, true, mean, log_var):
        precision = torch.exp(-log_var)
        return torch.sum(precision * (true - mean)**2 + log_var)
    
    def regularisation_loss(self):
        reg_loss = (self.forward_main[0].regularisation()
                    + self.forward_main[1].regularisation()
                    + self.forward_main[2].regularisation())
        reg_loss += self.forward_mean.regularisation()
        reg_loss += self.forward_logvar.regularisation()
        return reg_loss
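For context, here is roughly how the regularisation term gets added to the data loss during training (just a sketch; X, Y, n_epochs, the learning rate and the wr/dr values are placeholders, not taken from the notebook):

model = Model(wr=1e-6, dr=1e-5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(n_epochs):
    optimizer.zero_grad()
    mean, log_var = model(X)  # X: (N, 1) input batch
    # Heteroscedastic data loss plus the per-layer Concrete Dropout regularisers
    loss = model.heteroscedastic_loss(Y, mean, log_var) + model.regularisation_loss()
    loss.backward()
    optimizer.step()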

Thank you!

Everything works fine now; the uniform noise was missing the size argument:

unif_noise = Variable(torch.FloatTensor(np.random.uniform(size=tuple(x.size()))))
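With that change each element of x gets its own uniform draw, so the relaxed mask is sampled independently per unit rather than as a single shared scalar. For completeness, the updated forward then looks roughly like this (np.log becomes torch.log once the noise is a tensor; Variable is assumed to be imported from torch.autograd):

    def forward(self, x):
        eps = 1e-7
        temp = 0.1

        self.p = nn.functional.sigmoid(self.p_logit)

        # One uniform sample per element of x -> independent mask per unit
        unif_noise = Variable(torch.FloatTensor(np.random.uniform(size=tuple(x.size()))))

        drop_prob = (torch.log(self.p + eps)
                     - torch.log(1 - self.p + eps)
                     + torch.log(unif_noise + eps)
                     - torch.log(1 - unif_noise + eps))
        drop_prob = nn.functional.sigmoid(drop_prob / temp)
        random_tensor = 1 - drop_prob
        retain_prob = 1 - self.p

        x = torch.mul(x, random_tensor)
        x /= retain_prob

        return self.layer(x)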