Uncertainty Regression: Student-Teacher Distillation

Hi there,

My name is Steven and this is my first post in this forum. I apologize in advance for the long post, but I want to be as clear as possible. Here's my problem:

I'm trying to train a model on the Cityscapes dataset that outputs not only a segmentation map but also uncertainties. The model I use is a custom DeepLabV3+ model with a MobileNetV3 backbone. I replaced the regular output head with two new output heads, defined as follows:

```python
import torch.nn as nn


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        conv2d_3x3 = nn.Conv2d(in_channels, out_channels=256, kernel_size=3, padding=3//2)
        batch_norm = nn.BatchNorm2d(256)
        relu = nn.ReLU()
        conv2d_1x1 = nn.Conv2d(256, out_channels=num_classes, kernel_size=1)
        # Upsample by 4 so the output matches the input image resolution
        upsampling = nn.UpsamplingBilinear2d(scale_factor=4)
        super().__init__(conv2d_3x3, batch_norm, relu, conv2d_1x1, upsampling)


class UncertaintyHead(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        conv2d_3x3 = nn.Conv2d(in_channels, out_channels=256, kernel_size=3, padding=3//2)
        batch_norm = nn.BatchNorm2d(256)
        relu = nn.ReLU()
        conv2d_1x1 = nn.Conv2d(256, out_channels=num_classes, kernel_size=1)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=4)
        # Sigmoid keeps the per-class uncertainties in (0, 1)
        activation = nn.Sigmoid()
        super().__init__(conv2d_3x3, batch_norm, relu, conv2d_1x1, upsampling, activation)
```

I changed the forward pass method accordingly:

```python
def forward(self, x):
    features = self.encoder(x)
    decoder_output = self.decoder(*features)
    segmentation_map = self.segmentation_head(decoder_output)
    uncertainty_map = self.uncertainty_head(decoder_output)
    return segmentation_map, uncertainty_map
```

Both heads produce the expected output shapes: for an input batch of shape [batch_size, num_channels, img_height, img_width], each head returns a tensor of shape [batch_size, num_classes, img_height, img_width].
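A quick way to confirm the shapes without the full encoder/decoder is to feed a dummy tensor straight into the two heads. This is a minimal sketch that assumes the decoder output has 256 channels and sits at 1/4 of the input resolution (which is what `scale_factor=4` implies). The helper `check_head_shapes` and its default sizes are hypothetical.

```python
import torch
import torch.nn as nn


class SegmentationHead(nn.Sequential):
    """Same layers as in the post: 3x3 conv, BN, ReLU, 1x1 conv, 4x upsampling."""

    def __init__(self, in_channels, num_classes):
        super().__init__(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4),
        )


class UncertaintyHead(nn.Sequential):
    """Same as SegmentationHead plus a trailing Sigmoid, so outputs lie in (0, 1)."""

    def __init__(self, in_channels, num_classes):
        super().__init__(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4),
            nn.Sigmoid(),
        )


def check_head_shapes(batch_size=2, decoder_channels=256, num_classes=20, img_size=64):
    # Hypothetical decoder output at 1/4 of the input resolution.
    decoder_output = torch.randn(batch_size, decoder_channels, img_size // 4, img_size // 4)
    seg = SegmentationHead(decoder_channels, num_classes)(decoder_output)
    unc = UncertaintyHead(decoder_channels, num_classes)(decoder_output)
    return seg.shape, unc.shape, unc.min().item(), unc.max().item()


seg_shape, unc_shape, unc_min, unc_max = check_head_shapes()
print(seg_shape, unc_shape)
```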

Now to the part where I'm currently stuck: training the model. I use PyTorch Lightning, and this is what the training_step looks like:

```python
...
    def __init__(...):
        ...
        self.classification_criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
        self.distillation_criterion = nn.KLDivLoss(reduction='mean')
        self.uncertainty_criterion = nn.MSELoss(reduction='mean')

    def training_step(self, batch, batch_index):
        images, labels = batch

        # Forward pass: student
        prediction_map, uncertainty_map = self.model(images)

        # Forward pass: teacher ensemble (no gradients needed)
        ensemble_outputs = []
        for model in self.ensemble:
            model.to(images.device)
            model.eval()
            with torch.no_grad():
                ensemble_outputs.append(model(images))
        # Shape: [num_members, batch_size, num_classes, height, width]
        ensemble_outputs = torch.stack(ensemble_outputs, dim=0)

        # Softmax over the class dimension, computed once for both statistics
        ensemble_probabilities = torch.softmax(ensemble_outputs, dim=2)
        ensemble_probability_map = ensemble_probabilities.mean(dim=0).to(dtype=torch.float16)
        ensemble_uncertainty_map = ensemble_probabilities.std(dim=0).to(dtype=torch.float16)

        # Loss calculation (squeeze(1) drops only the label channel dim, so batch_size=1 still works)
        classification_loss = self.classification_criterion(prediction_map, labels.squeeze(1))
        distillation_loss = self.distillation_criterion(torch.log_softmax(prediction_map, dim=1), ensemble_probability_map)
        uncertainty_loss = self.uncertainty_criterion(uncertainty_map, ensemble_uncertainty_map)

        loss = 0.1 * classification_loss + distillation_loss + uncertainty_loss

        return {'loss': loss}
```
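The teacher statistics can be checked in isolation. The sketch below (the helper name `ensemble_statistics` and the toy logits are hypothetical) takes stacked teacher logits shaped like `ensemble_outputs`, applies the softmax over the class dimension once, and returns the per-pixel mean and standard deviation across ensemble members. Identical members give a standard deviation of 0 everywhere, while confidently disagreeing members push it towards the top of its range.

```python
import torch


def ensemble_statistics(ensemble_logits):
    """Per-pixel mean and std of teacher probabilities.

    ensemble_logits: raw teacher outputs stacked as
    [num_members, batch_size, num_classes, height, width], like ensemble_outputs.
    """
    # Softmax over the class dimension (dim=2), computed once and reused.
    probs = torch.softmax(ensemble_logits, dim=2)
    mean_probs = probs.mean(dim=0)  # [batch_size, num_classes, height, width]
    std_probs = probs.std(dim=0)    # unbiased (Bessel-corrected) estimator by default
    return mean_probs, std_probs


# Five identical members: they always agree, so the std is (numerically) 0.
agreeing = torch.randn(1, 1, 3, 2, 2).repeat(5, 1, 1, 1, 1)
_, agree_std = ensemble_statistics(agreeing)

# Five members that confidently pick classes 0, 1, 2, 0, 1 at every pixel.
disagreeing = torch.zeros(5, 1, 3, 2, 2)
for k in range(5):
    disagreeing[k, :, k % 3] = 20.0
dis_mean, dis_std = ensemble_statistics(disagreeing)

print(agree_std.max().item(), dis_std.max().item())
```

Because `torch.std` uses the unbiased estimator by default, a small ensemble can produce values above 0.5 (about 0.55 for class 0 in the disagreeing case above), although most pixels stay near 0 whenever the members agree.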

The main problem I can't figure out is the following: the combined loss (and each of its three components) slowly declines, and the segmentation map looks good, but the uncertainty outputs slowly converge towards 0 instead of learning the actual uncertainties (i.e., ensemble_uncertainty_map). For context, ensemble_uncertainty_map contains values in the range [0, 1], with most values close to 0 but some values > 0.

This should be a solvable regression problem; at least, there is scientific literature that has done similar work (unfortunately without open-sourcing the code). So what am I missing?

Thank you so much in advance and feel free to ask any questions!

That's an interesting use case. Based on your description, your concern is that the uncertainty_map outputs don't fit the target after a few training steps and instead collapse to zero.
You also mention that most target values are close to 0, while some are > 0.
Let's assume the worst case, a prediction of perfect 0s in uncertainty_map: how large would uncertainty_loss be? If it's small compared to the other two losses, this behavior might be expected, since training minimizes the "global" loss (i.e., the sum of all losses). In that case you might want to use different scaling factors (you are already down-weighting classification_loss by a factor of 10).
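The worst-case question can be answered with a bit of arithmetic: for an all-zero prediction, each pixel contributes the square of its target value to the MSE. The sketch below uses made-up proportions (95% of pixels at 0.01, 5% at 0.3), not numbers from the thread; `zero_prediction_mse` is an illustrative helper.

```python
def zero_prediction_mse(pixel_groups):
    """MSE of an all-zero prediction.

    pixel_groups: list of (fraction_of_pixels, target_value) pairs whose
    fractions sum to 1. A prediction of 0 contributes target_value**2 per pixel.
    """
    return sum(fraction * value ** 2 for fraction, value in pixel_groups)


# Hypothetical target: 95% of pixels at std 0.01, 5% at std 0.3.
groups = [(0.95, 0.01), (0.05, 0.3)]
baseline = zero_prediction_mse(groups)
share_from_rare = 0.05 * 0.3 ** 2 / baseline

print(f"zero-prediction MSE: {baseline:.6f}")           # 0.004595
print(f"share from rare pixels: {share_from_rare:.3f}")  # 0.979
```

With these hypothetical numbers the zero prediction already reaches an MSE of about 0.0046, and roughly 98% of it comes from the rare high-uncertainty pixels, so those few pixels carry almost all of the signal that could pull the head away from 0.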

The uncertainty_loss is very low if uncertainty_map is full of perfect 0s, which is why I introduced the loss scaling in the first place. Because of that, I also tried using only the uncertainty_loss, which didn't work either: same behaviour as with the combined loss.

I also tried transforming the uncertainty values before passing them to the MSE loss. Here's an example:

```python
uncertainty_loss = self.uncertainty_criterion(torch.pow(uncertainty_map, 0.2), torch.pow(ensemble_uncertainty_map, 0.2))
```
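One way to see what the power transform changes is to compare the gradient that reaches the pre-sigmoid logit of the uncertainty head for a single pixel. The sketch below assumes a prediction of 0.001 and a target of 0.01 (hypothetical values); `logit_grad` is an illustrative helper. For plain MSE the gradient with respect to the logit is 2(s - t)s(1 - s), which becomes tiny when both s and t are close to 0 because the sigmoid derivative s(1 - s) is small there; with the 0.2 power it becomes 0.4(s^0.2 - t^0.2)s^0.2(1 - s).

```python
import torch
import torch.nn.functional as F


def logit_grad(pred_prob, target, exponent=None):
    """Gradient of a one-pixel MSE loss w.r.t. the pre-sigmoid logit z.

    pred_prob is sigmoid(z), i.e. the uncertainty head's output for that pixel.
    If exponent is given, prediction and target are both raised to that power
    before the MSE, as in the torch.pow(..., 0.2) variant.
    """
    z = torch.logit(torch.tensor(pred_prob, dtype=torch.float64)).requires_grad_()
    s = torch.sigmoid(z)
    t = torch.tensor(target, dtype=torch.float64)
    if exponent is not None:
        s, t = torch.pow(s, exponent), torch.pow(t, exponent)
    F.mse_loss(s, t).backward()
    return z.grad.item()


# Hypothetical pixel: the head predicts 0.001, the teacher std is 0.01.
g_plain = logit_grad(1e-3, 1e-2)
g_pow = logit_grad(1e-3, 1e-2, exponent=0.2)
print(g_plain, g_pow)
```

For this pixel both gradients are negative (they push the logit up, towards the target), and the transformed loss gives a gradient several hundred times larger than plain MSE. Whether that helps in practice also depends on the averaging over all pixels and on the other loss terms.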

No matter what I tried, I can’t even get the network to overfit the uncertainties of a couple of images.
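Since the head cannot even overfit a couple of images, a stripped-down overfit test can help separate an optimization problem from a data or pipeline problem. The sketch below is entirely hypothetical (a toy 1x1-conv head, a random 4-channel input, and a sparse target derived from that input so it is learnable). It reports the final loss next to the zero-prediction baseline; a loss that stalls near that baseline is what a collapse towards 0 would look like.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def overfit_check(steps=300, lr=1e-2, seed=0):
    """Fit a toy sigmoid-output head to one fixed batch with MSE only.

    Hypothetical setup: 4-channel random features stand in for decoder output;
    the target is 0.3 where channel 0 exceeds 1.6 (a few percent of pixels)
    and 0.01 elsewhere, mimicking a sparse uncertainty map.
    """
    torch.manual_seed(seed)
    features = torch.randn(2, 4, 32, 32)
    target = torch.where(features[:, :1] > 1.6, torch.tensor(0.3), torch.tensor(0.01))

    head = nn.Sequential(
        nn.Conv2d(4, 32, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(32, 1, kernel_size=1),
        nn.Sigmoid(),
    )
    optimizer = torch.optim.Adam(head.parameters(), lr=lr)

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.mse_loss(head(features), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Loss of a prediction that is exactly 0 everywhere, for comparison.
    zero_baseline = F.mse_loss(torch.zeros_like(target), target).item()
    return losses, zero_baseline


losses, zero_baseline = overfit_check()
print(f"first={losses[0]:.4f} last={losses[-1]:.6f} zero-baseline={zero_baseline:.6f}")
```

The same loop can be pointed at the real model and a single real batch; only the loss on that batch matters for this check.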

Does this mean that your model isn't able to overfit the uncertainties of a small dataset when this loss is used alone (without the other losses)?

Yes, exactly.

For reference, I’m trying to rebuild what this paper did.