Bayesian Deep Learning with Monte Carlo Dropout in PyTorch

I am trying to implement a Bayesian CNN using MC Dropout in PyTorch. The main idea is that by applying dropout at test time and running many forward passes, you get predictions from a variety of different models. I need to obtain the uncertainty; does anyone have an idea of how I can do it, please?
This is how I defined my CNN:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(p=0.3)

        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.constant_(self.conv1.bias, 0.0)
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.constant_(self.conv2.bias, 0.0)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)
        nn.init.xavier_uniform_(self.fc3.weight)
        nn.init.constant_(self.fc3.bias, 0.0)

    def forward(self, x):
        x = self.pool(F.relu(self.dropout(self.conv1(x))))  # dropout before the ReLU
        x = self.pool(F.relu(self.dropout(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(self.dropout(x)))
        x = self.fc3(self.dropout(x))  # no activation on the last layer; CrossEntropyLoss expects raw logits
        return x


model = Net().to(device)

train_accuracies = np.zeros(num_epochs)
test_accuracies = np.zeros(num_epochs)

dataiter = iter(trainloader)
images, labels = next(dataiter)  # .next() was removed from DataLoader iterators; use next()

# initializing variables
loss_acc = []
class_acc_mcdo = []
start_train = True

# Defining the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def train():
    loss_vals = []
    acc_vals = []

    model.train()  # enable dropout for training

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        n_correct = 0  # number of correct predictions
        acc = 0        # accuracy of each epoch
        somme = 0      # mean of the losses of each epoch
        epoch_loss = []

        for i, (images, labels) in enumerate(trainloader):
            # origin shape: [4, 3, 32, 32]
            # input layer: 3 input channels, 6 output channels, kernel size 5
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()  # zero the parameter gradients
            loss.backward()
            epoch_loss.append(loss.item())  # add the loss to the epoch_loss list
            optimizer.step()

            # max returns (value, index)
            _, predicted = torch.max(outputs, 1)
            n_correct += (predicted == labels).sum().item()

            # print statistics
            if (i + 1) % 2000 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{n_total_steps}], '
                      f'Loss: {loss.item():.4f}')

        somme = sum(epoch_loss) / len(epoch_loss)
        loss_vals.append(somme)  # add the epoch's mean loss to loss_vals
        print("Loss = {}".format(somme))

        acc = 100 * n_correct / len(trainset)
        acc_vals.append(acc)  # add the epoch's accuracy to acc_vals
        print("Accuracy = {}".format(acc))

        # SAVE
        PATH = './cnn.pth'
        torch.save(model.state_dict(), PATH)

    loss_acc.append(loss_vals)
    loss_acc.append(acc_vals)
    return loss_acc  # return after all epochs, not inside the loop

And here is the code for MC dropout:

def enable_dropout(model):
    """ Function to enable the dropout layers during test-time """
    for m in model.modules():
        if m.__class__.__name__.startswith('Dropout'):
            m.train()
  
def test():
    # set non-dropout layers to eval mode
    model.eval()

    # set dropout layers to train mode
    enable_dropout(model)

    test_loss = 0
    correct = 0
    n_samples = 0
    n_class_correct = [0 for i in range(10)]
    n_class_samples = [0 for i in range(10)]
    T = 100

    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            output_list = []
            # getting outputs for T forward passes
            for i in range(T):
                output_list.append(torch.unsqueeze(model(images), 0))

        # calculating mean
        output_mean = torch.cat(output_list, 0).mean(0)

        test_loss += F.nll_loss(F.log_softmax(output_mean, dim=1), labels,
                                reduction='sum').item()  # sum up batch loss
        _, predicted = torch.max(output_mean, 1)  # get the index of the max log-probability
        correct += (predicted == labels).sum().item()  # sum up correct predictions
        n_samples += labels.size(0)

        for i in range(labels.size(0)):  # use the actual batch size; the last batch may be smaller
            label = labels[i]
            predi = predicted[i]

            if (label == predi):
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    test_loss /= len(testloader.dataset)

    # PRINT TO HTML PAGE
    print('\n Average loss: {:.4f}, Accuracy:  ({:.3f}%)\n'.format(
        test_loss,
        100. * correct / n_samples))


    # Accuracy for each class
    acc_classes = []
    for i in range(10):
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc} %')
        acc_classes.append(acc)

    class_acc_mcdo.extend(acc_classes)
    print('Finished Testing')

Here you’re creating a sampling distribution but only using it for .mean(0); you can also use it to compute the uncertainty estimates you need.

Something similar to torch.cat(output_list, 0).softmax(-1).std(0) (and gather the predicted classes). Note that the method is .std, not .sd, and that the stacked tensor has shape (T, batch, num_classes), so the softmax goes over the last dimension.
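
For example, a minimal sketch, assuming output_list holds your T tensors (each unsqueezed to (1, batch, num_classes)):

# stack the T stochastic forward passes: shape (T, batch, num_classes)
stacked = torch.cat(output_list, 0)

# convert logits to probabilities per pass
probs = stacked.softmax(-1)                # (T, batch, num_classes)

mean_probs = probs.mean(0)                 # predictive mean, (batch, num_classes)
std_probs = probs.std(0)                   # per-class predictive std, (batch, num_classes)

# predicted class and the spread of that class across the T passes
confidence, predicted = mean_probs.max(1)  # both (batch,)
pred_uncertainty = std_probs.gather(1, predicted.unsqueeze(1)).squeeze(1)

A large pred_uncertainty means the T dropout models disagree about the winning class for that sample.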

I’m working on classification of images. I need to calculate the entropy for each class; how could I do that (sum up each sample’s entropy, or how exactly)? Please!

Perhaps start with

scores = torch.cat(output_list, 0)
probs = torch.softmax(scores, dim=-1)
log_probs = probs.clamp_min(1e-6).log()
per_class_entropy = probs * log_probs  # shape: (num_models, batch, num_classes)

At this point, I’m not sure about the “per-class entropy” formula; that would dictate how to reduce the last tensor. Perhaps -per_class_entropy.mean((0, 1)).
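
If what you’re after is the standard MC-dropout uncertainty decomposition (Gal & Ghahramani), the usual per-sample quantities are the entropy of the mean predictive distribution and the mutual information (BALD). A minimal sketch, assuming output_list holds the T logit tensors as above:

# stack T stochastic passes and convert to probabilities: (T, batch, num_classes)
probs = torch.cat(output_list, 0).softmax(-1)
mean_probs = probs.mean(0)  # predictive distribution, (batch, num_classes)

eps = 1e-6  # guard against log(0)

# predictive entropy: total uncertainty per sample, (batch,)
pred_entropy = -(mean_probs * mean_probs.clamp_min(eps).log()).sum(-1)

# expected entropy of the individual passes (aleatoric part), (batch,)
exp_entropy = -(probs * probs.clamp_min(eps).log()).sum(-1).mean(0)

# mutual information / BALD (epistemic part), (batch,)
mutual_info = pred_entropy - exp_entropy

For a per-class figure, you could then average pred_entropy over the samples labelled with (or predicted as) each class.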

probs gives you the probability of each class for every sample in the batch?

Yes, I think so.
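
A quick shape check confirms it (assuming output_list holds the T stochastic passes):

probs = torch.softmax(torch.cat(output_list, 0), dim=-1)
print(probs.shape)    # torch.Size([T, batch_size, num_classes])
print(probs.sum(-1))  # all (approximately) 1: one distribution per pass and per sample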

I am working on this.
What do you think? Any advice/opinion, please?

@MERAH_Samia Can you please explain how you implemented MC dropout in PyTorch?

Send me an email: merahsamia@gmail.com