How to use multi-sample dropout in a pretrained PyTorch model?

I am loading a pretrained VoVNet model (from the timm package) and using dropout in it like this:

import timm
import torch.nn as nn

class CassvaImgClassifier(nn.Module):
    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = self.model.head.fc.in_features
        #self.model.head.fc = nn.Linear(n_features, n_class)
        self.model.head.fc = nn.Sequential(
            nn.Dropout(p=0.5),  # dropout rate not shown in the original; 0.5 is assumed
            #nn.Linear(n_features, hidden_size,bias=True), nn.ELU(),
            nn.Linear(n_features, n_class, bias=True)
        )

    def forward(self, x):
        x = self.model(x)
        return x

I want to replace the dropout with multi-sample dropout: GitHub - lonePatient/multi-sample_dropout_pytorch: a simple pytorch implement of Multi-Sample Dropout

What should the updated code look like?

Here is a sample implementation. The key points are:

  • Each channel will be zeroed out independently on every forward call (from the nn.Dropout() docs).
  • The final logits are the average of the logits of all classifiers (from the paper).
  • At test time, passing features through a single classifier is enough (from the paper).
  • nn.CrossEntropyLoss() returns the mean loss by default.
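A quick way to see the first and third points in action, assuming a recent PyTorch: the same `nn.Dropout` layer draws a fresh mask on every call in training mode, while in eval mode it passes its input through unchanged, so a single pass through the classifier already gives deterministic logits. The layer sizes and `p=0.5` are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A toy classifier head: dropout followed by a linear layer (sizes are arbitrary).
head = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(64, 2))
features = torch.randn(4, 64)  # stand-in for backbone features, batch of 4

head.train()  # dropout active: each call samples a new mask
out_a = head(features)
out_b = head(features)
train_calls_differ = not torch.equal(out_a, out_b)

head.eval()  # dropout disabled: the layer returns its input unchanged
out_c = head(features)
out_d = head(features)
eval_calls_match = torch.equal(out_c, out_d)

# In eval mode the output equals the linear layer applied to the raw features.
matches_linear_only = torch.allclose(out_c, head[1](features))

print(train_calls_differ, eval_calls_match, matches_linear_only)
```

This is why repeated classifier passes only make sense in training mode: in eval mode they would all return the same logits.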

First, we create a new module that takes a backbone as the feature extractor, plus a custom classifier that contains a dropout layer. Multi-sample dropout (MSD) is achieved by calling the classifier several times in a for loop and collecting the logits; since dropout draws a new mask on every call, each pass produces different logits. Later, we deal with the logits inside the training loop.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.optim as optim

class PretrainedMSD(nn.Module):
    def __init__(self, model, classifier, dropout_num=8):
        super(PretrainedMSD, self).__init__()
        self.features = model
        self.classifier = classifier
        self.dropout_num = dropout_num

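    def forward(self, x):
        x = self.features(x)  # Make sure x is flattened before classification layer
        if self.training:
            out = []
            for i in range(self.dropout_num):
                out.append(self.classifier(x))  # Dropout draws a new mask on each call
            out = torch.cat(out, dim=0)  # Concatenate the results (size= dropout_num * batch_size, num_classes)
        else:
            out = self.classifier(x)  # Single pass at test time
        return out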
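To plug the question's timm VoVNet into this wrapper, one option is to swap its `head.fc` for `nn.Identity()`, the same trick used on `base_net.classifier` for DenseNet121 in this answer. The helper `split_head_fc` below is hypothetical (not part of timm or PyTorch); it assumes, as the question's code does, that the final linear layer lives at `model.head.fc`, and that the layers before it already return flattened features. The dropout rate `p=0.5` is also an assumption. So that the sketch runs without timm, it is demonstrated on a tiny stand-in model with the same `head.fc` layout.

```python
import torch
import torch.nn as nn


def split_head_fc(backbone, n_class, p=0.5):
    """Hypothetical helper, not part of timm or PyTorch.

    Assumes backbone.head.fc is the final nn.Linear (as in the question's
    timm VoVNet) and that the layers before it return flattened features.
    Returns (feature_extractor, classifier) for PretrainedMSD.
    """
    n_features = backbone.head.fc.in_features
    backbone.head.fc = nn.Identity()  # backbone now returns features, not logits
    classifier = nn.Sequential(
        nn.Dropout(p=p),  # assumed rate; a new mask is drawn on every call in train mode
        nn.Linear(n_features, n_class, bias=True),
    )
    return backbone, classifier


# Tiny stand-in with the same head.fc layout, so the sketch runs without timm.
class ToyHead(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
        self.head = ToyHead(32, 1000)

    def forward(self, x):
        return self.head(self.body(x))


features, classifier = split_head_fc(ToyBackbone(), n_class=5)
images = torch.randn(4, 3, 8, 8)
feats = features(images)    # flattened features: (4, 32)
logits = classifier(feats)  # one classifier pass: (4, 5)
print(feats.shape, logits.shape)
```

With timm installed, the question's model would then be wrapped as `model = PretrainedMSD(*split_head_fc(timm.create_model(model_arch, pretrained=True), n_class))`; in the training loop of this answer, the hard-coded `2` in `logits.view(model.dropout_num, -1, 2)` would become `n_class`.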

Below, we use DenseNet121 as the backbone and perform a binary classification task on fake image data. During training, MSD is active and the model returns 8 times as many logits as usual (dropout_num * batch_size rows). We simply repeat the original labels to match the new logits shape. Then we reshape the logits and average them over the dropout samples. During evaluation, the model returns the normal logits and the loss is computed in the usual way.

if __name__ == '__main__':
    base_net = models.densenet121(pretrained=True)
    num_ftrs = base_net.classifier.in_features  # Needed to create the new classifier
    base_net.classifier = nn.Identity()  # By-pass the original classifier

    classifier = nn.Sequential(
        nn.Dropout(p=0.5),  # Source of randomness for MSD; 0.5 is an assumed rate
        nn.Linear(num_ftrs, 2, bias=True)  # Binary classification
    )

    model = PretrainedMSD(base_net, classifier)  # Default dropout number is 8

    transform = transforms.Compose([
        transforms.ToTensor(),  # FakeData yields PIL images
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    trainset = torchvision.datasets.FakeData(size=1000, num_classes=2, transform=transform)  # Create fake data
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    model.train()  # The MSD branch in forward() is only used in training mode
    for epoch in range(10):  # loop over the dataset multiple times
        for images, labels in trainloader:
            optimizer.zero_grad()
            logits = model(images)
            if model.training:
                loss = criterion(logits, labels.repeat(model.dropout_num))  # Average loss
                loss.backward()
                optimizer.step()
                logits = logits.view(model.dropout_num, -1, 2).mean(0)  # Average logits
            else:
                loss = criterion(logits, labels)  # For evaluation
            acc = logits.argmax(dim=-1) == labels
            acc = acc.sum() / len(acc)
            print('Loss: ', loss.item(), 'Acc: ', acc.item())
    print('Finished Training')
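The label bookkeeping in the training loop relies on two layout facts, assuming the forward pass concatenates the per-pass logits with `torch.cat(out, dim=0)`: the `dropout_num` blocks are stacked one after another, which is exactly the order produced by `labels.repeat(dropout_num)`, and `view(dropout_num, -1, num_classes).mean(0)` averages those blocks element-wise. Since `nn.CrossEntropyLoss()` averages by default, the loss on the concatenated logits equals the mean of the per-pass losses. The sketch below checks both facts with random tensors standing in for the classifier outputs; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout_num, batch_size, num_classes = 8, 5, 2

# Random logits standing in for dropout_num separate classifier passes.
per_sample = [torch.randn(batch_size, num_classes) for _ in range(dropout_num)]
labels = torch.randint(0, num_classes, (batch_size,))

# Training-mode output of the MSD wrapper: blocks stacked along dim 0.
logits = torch.cat(per_sample, dim=0)  # shape (dropout_num * batch_size, num_classes)

criterion = nn.CrossEntropyLoss()  # reduction='mean' by default

# Loss on all rows at once, with labels tiled to match the block order.
loss_concat = criterion(logits, labels.repeat(dropout_num))
# Same quantity computed pass by pass, then averaged.
loss_mean = torch.stack([criterion(l, labels) for l in per_sample]).mean()

# Averaged logits used for prediction.
avg_logits = logits.view(dropout_num, -1, num_classes).mean(0)
avg_reference = torch.stack(per_sample).mean(0)

print(torch.allclose(loss_concat, loss_mean), torch.allclose(avg_logits, avg_reference))
```

Because every pass has the same batch size, the mean over all `dropout_num * batch_size` rows is the same as the mean of the per-pass means.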

If I am not missing anything important, this should work. What do you think?