Hi there,
my name is Steven and this is my first post in this forum. I apologize in advance for the long post, but I want to be as clear as possible. Here’s my problem:
I’m trying to train a model on the Cityscapes dataset to output not only a segmentation map but also uncertainties. The model is a custom DeepLabV3+ with a MobileNetV3 backbone. I replaced the regular output head with two new output heads, defined as follows:
import torch.nn as nn


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        # 3x3 conv -> BN -> ReLU, then a 1x1 projection to one logit map per class
        conv2d_3x3 = nn.Conv2d(in_channels, out_channels=256, kernel_size=3, padding=3//2)
        batch_norm = nn.BatchNorm2d(256)
        relu = nn.ReLU()
        conv2d_1x1 = nn.Conv2d(256, out_channels=num_classes, kernel_size=1)
        # x4 bilinear upsampling from the decoder resolution to the input resolution
        upsampling = nn.UpsamplingBilinear2d(scale_factor=4)
        super().__init__(conv2d_3x3, batch_norm, relu, conv2d_1x1, upsampling)


class UncertaintyHead(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        # Same layers as SegmentationHead ...
        conv2d_3x3 = nn.Conv2d(in_channels, out_channels=256, kernel_size=3, padding=3//2)
        batch_norm = nn.BatchNorm2d(256)
        relu = nn.ReLU()
        conv2d_1x1 = nn.Conv2d(256, out_channels=num_classes, kernel_size=1)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=4)
        # ... plus a Sigmoid that squashes every output into (0, 1)
        activation = nn.Sigmoid()
        super().__init__(conv2d_3x3, batch_norm, relu, conv2d_1x1, upsampling, activation)
I changed the forward method accordingly:
def forward(self, x):
    features = self.encoder(x)
    decoder_output = self.decoder(*features)
    segmentation_map = self.segmentation_head(decoder_output)
    uncertainty_map = self.uncertainty_head(decoder_output)
    return segmentation_map, uncertainty_map
The output shapes of both heads are as expected. If I input a batch of images with shape [batch_size, num_channels, img_height, img_width], I get an output shape of [batch_size, num_classes, img_height, img_width] for both heads.
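To make the shape check above easy to reproduce, here is a minimal, self-contained sketch of the same two-head forward pass. `ToyEncoder` and `ToyDecoder` are hypothetical stand-ins, not the real MobileNetV3/DeepLabV3+ modules: they only assume that the decoder emits 256-channel features at 1/4 of the input resolution, and the heads are slimmed down to the 1x1 projection, the x4 bilinear upsampling and (for the uncertainty branch) the Sigmoid.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the backbone: returns a list of feature maps."""

    def __init__(self, in_channels=3):
        super().__init__()
        # Downsample by 4 so the heads' x4 upsampling restores the input size.
        self.stem = nn.Conv2d(in_channels, 32, kernel_size=3, stride=4, padding=1)

    def forward(self, x):
        return [self.stem(x)]


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for the DeepLabV3+ decoder: 256-channel output."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(32, 256, kernel_size=1)

    def forward(self, *features):
        return self.proj(features[-1])


class TwoHeadNet(nn.Module):
    def __init__(self, num_classes=20):
        super().__init__()
        self.encoder = ToyEncoder()
        self.decoder = ToyDecoder()
        # Slimmed-down heads: 1x1 projection and x4 upsampling (+ Sigmoid).
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4),
        )
        self.uncertainty_head = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.encoder(x)
        decoder_output = self.decoder(*features)
        return self.segmentation_head(decoder_output), self.uncertainty_head(decoder_output)


model = TwoHeadNet(num_classes=20).eval()
with torch.no_grad():
    seg, unc = model(torch.randn(2, 3, 64, 64))
print(seg.shape, unc.shape)
```

Both outputs come back as [batch_size, num_classes, img_height, img_width], and the uncertainty branch is bounded by the Sigmoid.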
Now to the part where I’m currently stuck: training the model. I use PyTorch Lightning, and this is how __init__ and training_step look:
import torch
import torch.nn as nn

...
def __init__(...):
    ...
    self.classification_criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
    self.distillation_criterion = nn.KLDivLoss(reduction='mean')
    self.uncertainty_criterion = nn.MSELoss(reduction='mean')

def training_step(self, batch, batch_index):
    images, labels = batch

    # Forward pass of the student
    prediction_map, uncertainty_map = self.model(images)

    # Forward pass of the teacher ensemble (frozen, no gradients)
    ensemble_outputs = torch.empty(size=[len(self.ensemble), self.batch_size, 20, 768, 768], device=DEVICE)
    for i, model in enumerate(self.ensemble):
        model.cuda()
        model.eval()
        with torch.no_grad():
            output = model(images)
        ensemble_outputs[i] = output

    # Per-pixel mean and std of the members' class probabilities
    ensemble_probs = torch.softmax(ensemble_outputs, dim=2)
    ensemble_probability_map = torch.mean(ensemble_probs, dim=0).to(dtype=torch.float16)
    ensemble_uncertainty_map = torch.std(ensemble_probs, dim=0).to(dtype=torch.float16)

    # Loss calculation
    # squeeze(1) drops only the channel dim of [B, 1, H, W] labels (keeps B when B == 1)
    classification_loss = self.classification_criterion(prediction_map, labels.squeeze(1))
    distillation_loss = self.distillation_criterion(torch.log_softmax(prediction_map, dim=1), ensemble_probability_map)
    uncertainty_loss = self.uncertainty_criterion(uncertainty_map, ensemble_uncertainty_map)
    loss = 0.1 * classification_loss + distillation_loss + uncertainty_loss
    return {'loss': loss}
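The teacher part of training_step can also be written without the hard-coded `[len(self.ensemble), self.batch_size, 20, 768, 768]` buffer, which only matches full-size 768×768 batches. A minimal sketch, assuming each ensemble member already sits on the same device as `images` and returns raw logits of shape [B, C, H, W]; the function name `ensemble_targets` and the 1x1-conv "members" in the usage example are hypothetical:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ensemble_targets(ensemble, images):
    """Return the per-pixel mean and std of the members' softmax probabilities.

    Assumes every member maps [B, 3, H, W] images to [B, C, H, W] logits.
    """
    for member in ensemble:
        member.eval()  # frozen teachers: no dropout, BatchNorm uses running stats
    # Stack to [M, B, C, H, W]; batch size and resolution come from the input.
    logits = torch.stack([member(images) for member in ensemble], dim=0)
    probs = torch.softmax(logits, dim=2)  # softmax over the class dimension
    return probs.mean(dim=0), probs.std(dim=0)


# Usage with hypothetical toy members (1x1 convs producing 20-class logits).
torch.manual_seed(0)
members = [nn.Conv2d(3, 20, kernel_size=1) for _ in range(5)]
images = torch.randn(3, 3, 32, 32)  # deliberately a "partial" batch
mean_probs, std_probs = ensemble_targets(members, images)
print(mean_probs.shape, std_probs.shape)
print(mean_probs.sum(dim=1).allclose(torch.ones(3, 32, 32)))
```

Inside training_step, a single `ensemble_targets(self.ensemble, images)` call would then replace the preallocated buffer and the loop.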
The main problem I can’t seem to understand is this: the loss declines slowly (as do all three of its components separately) and the segmentation map looks good, but the uncertainty outputs slowly converge towards 0 instead of learning the target uncertainties (ensemble_uncertainty_map). For context, ensemble_uncertainty_map contains values in the range [0, 1]; most of them are close to 0, but some are noticeably greater than 0.
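To put the last point in numbers: under MSE, an all-zero prediction scores exactly the mean of the squared target values, so on a target that is mostly near 0 it already gets a small loss. A minimal sketch on a hypothetical, hand-built target (95 zeros and 5 values of 0.4, not real Cityscapes statistics), comparing the all-zero prediction with the best constant prediction (the target mean):

```python
import torch
import torch.nn.functional as F

# Hypothetical target: 100 "pixels", 95 with zero uncertainty and 5 with 0.4.
target = torch.zeros(100)
target[:5] = 0.4

zero_pred = torch.zeros_like(target)
mean_pred = torch.full_like(target, target.mean().item())  # best constant under MSE

mse_zero = F.mse_loss(zero_pred, target).item()  # equals mean(target ** 2)
mse_mean = F.mse_loss(mean_pred, target).item()  # equals the (population) variance
print(f"all-zero MSE: {mse_zero:.4f}, constant-mean MSE: {mse_mean:.4f}")
```

This prints `all-zero MSE: 0.0080, constant-mean MSE: 0.0076`. Because these values are small in absolute terms, the uncertainty term may contribute comparatively little to `0.1 * classification_loss + distillation_loss + uncertainty_loss`; comparing `uncertainty_loss` with `(ensemble_uncertainty_map ** 2).mean()` on real batches is one quick check.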
This should be a solvable regression problem; at least, there is scientific literature that has done similar work (unfortunately without open-sourcing the code). So what am I missing?
Thank you so much in advance and feel free to ask any questions!