Keypoint detection output always converges to zero

Hi everyone,

I am trying to implement a keypoint detector for radiographic images. I have read some articles about keypoint detection on people, and the dominant approach is an hourglass architecture that outputs a map with one channel per point. To train the model, most references create the ground-truth maps by placing, in each channel, a Gaussian centered on the corresponding point with a small variance (about 1 px). MSE is then used as the loss during training.
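
My target creation follows that recipe; roughly like this (a simplified sketch, not my exact code; the (x, y) coordinate convention and sigma value are assumptions):

import torch

def make_heatmaps(keypoints, height, width, sigma=1.0):
    # keypoints: float tensor of shape (num_points, 2) holding (x, y) pixel coords
    ys = torch.arange(height, dtype=torch.float32).view(-1, 1)  # (H, 1)
    xs = torch.arange(width, dtype=torch.float32).view(1, -1)   # (1, W)
    maps = []
    for x, y in keypoints:
        d2 = (xs - x) ** 2 + (ys - y) ** 2              # squared distance to the point
        maps.append(torch.exp(-d2 / (2 * sigma ** 2)))  # peaks at 1 on the keypoint
    return torch.stack(maps)                            # (num_points, H, W), values in [0, 1]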

I tried to replicate this approach, but my model's output becomes a tensor of zeros for all channels very rapidly, which is understandable given how sparse the targets are. The problem is that after converging to this the model spends dozens of epochs with no improvement in the loss.
I really cannot figure out what I might be doing wrong.

This is the architecture I am using:

import torch
import torch.nn as nn
import torch.nn.functional as F


class UConvDown(nn.Module):
    """Two 3x3 convolutions followed by 2x2 max pooling. Returns both the
    pre-pool features (for the skip connection) and the pooled output."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        *args,
        **kwargs,
    ):
        super().__init__()

        self.conv0 = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding=1
        )
        self.conv1 = nn.Conv2d(
            out_channels, out_channels, kernel_size, (1, 1), padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)

    def forward(self, x):
        x = F.relu(self.conv0(x))
        conv = F.relu(self.conv1(x))
        pooled = self.pool(conv)
        return conv, pooled


class UConvUp(nn.Module):
    """Transposed-convolution upsampling, concatenation with the skip
    features, then two 3x3 convolutions."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        *args,
        **kwargs,
    ):
        super().__init__()

        self.convt = nn.ConvTranspose2d(
            in_channels, out_channels, (2, 2), (2, 2), padding=0
        )
        self.conv0 = nn.Conv2d(
            out_channels * 2, out_channels, kernel_size, (1, 1), padding=1
        )
        self.conv1 = nn.Conv2d(
            out_channels, out_channels, kernel_size, (1, 1), padding=1
        )

    def forward(self, x, conv):
        x = self.convt(x)                # upsample by 2
        x = torch.cat((x, conv), dim=1)  # concatenate the skip features
        x = F.relu(self.conv0(x))
        x = F.relu(self.conv1(x))
        return x


class SimpleUnet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = UConvDown(in_channels, 32)
        self.conv2 = UConvDown(32, 64)
        self.conv3 = UConvDown(64, 128)
        self.conv4 = UConvDown(128, 256)
        self.conv5 = UConvDown(256, 512)

        self.conv6 = UConvUp(512, 256)
        self.conv7 = UConvUp(256, 128)
        self.conv8 = UConvUp(128, 64)
        self.conv9 = UConvUp(64, 32)

        self.conv10 = nn.Conv2d(32, out_channels, (1, 1), (1, 1))

    def forward(self, x):
        conv1, x = self.conv1(x)
        conv2, x = self.conv2(x)
        conv3, x = self.conv3(x)
        conv4, x = self.conv4(x)
        conv5, x = self.conv5(x)  # the pooled output of the last block is unused

        x = self.conv6(conv5, conv4)
        x = self.conv7(x, conv3)
        x = self.conv8(x, conv2)
        x = self.conv9(x, conv1)

        return self.conv10(x)  # raw logits, one channel per keypoint
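
A quick shape sanity check (the channel counts here are just placeholders; the input side length should be a multiple of 16 so the skip connections line up):

model = SimpleUnet(in_channels=1, out_channels=6)
out = model(torch.randn(2, 1, 256, 256))
print(out.shape)  # torch.Size([2, 6, 256, 256]) -- raw logits, no activation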

I am using PyTorch Lightning to handle the training loop.

This is an example of the output for a certain point:

[image: ground-truth heatmap (left) vs. predicted map at epoch 49 (right)]

On the left is the ground truth for one of the channels, a dot formed by a Gaussian with a variance of about 1 px. On the right is the predicted map at epoch 49. My dataset has around 20k examples, so by epoch 49 the model should have seen plenty of them.

Here is my training step:

    def training_step(self, batch, batch_idx):
        x, mask = batch
        mask_pred = torch.sigmoid(self(x))  # squash the logits into [0, 1]
        loss = F.mse_loss(mask_pred, mask)

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        if batch_idx % 200 == 0:
            self.log_masks(mask, mask_pred, batch_idx)

        return loss

The Gaussian targets are normalized to the [0, 1] range, so I apply a sigmoid to the output.

Based on the target map you've posted, it seems you are dealing with a heavily imbalanced dataset, i.e. the majority of pixels belong to the background class while only a small fraction represents the keypoint.
The model can thus easily overfit to the background class, as that decreases the loss quickly and would also yield a high accuracy.
You could try e.g. a weighted loss to force the model to learn the foreground class.
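
To see how small the incentive to move away from zero is, compare the MSE of an all-zero prediction against a sparse Gaussian target (the map size here is made up for illustration):

import torch
import torch.nn.functional as F

# a 256x256 target containing a single Gaussian with sigma = 1 px
target = torch.zeros(1, 1, 256, 256)
ys, xs = torch.meshgrid(torch.arange(256.), torch.arange(256.), indexing='ij')
target[0, 0] = torch.exp(-((xs - 128) ** 2 + (ys - 128) ** 2) / 2.0)

# predicting all zeros already achieves an almost perfect loss
print(F.mse_loss(torch.zeros_like(target), target).item())  # ~5e-5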

Thank you for your reply. I thought that might be an issue. Most of the pixels on the map are zero, but the non-zero pixels belong to a normalized Gaussian (they are float values ranging from 0 to 1), so I do not have a discrete positive class. How would I approach weighting in this case, i.e., with MSE as the loss and continuous float-valued targets?

I am using this repository as reference: https://github.com/microsoft/human-pose-estimation.pytorch

The loss they use is defined as follows:

class JointsMSELoss(nn.Module):
    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        # size_average is deprecated in current PyTorch; reduction='mean' is equivalent
        self.criterion = nn.MSELoss(size_average=True)
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints

It seems to be a simple MSE loss with a weight applied to each channel as a whole, which would not fix the imbalance between zero and non-zero values.

You could create a weight tensor in the same shape as your output and target, e.g. via:

# setup
target = torch.empty(2, 1, 24, 24).uniform_(0, 1)
output = torch.randn(2, 1, 24, 24)

# weight map creation: pixels above the threshold get weight 10,
# everything else gets weight 0 and is ignored by the loss
threshold = 0.2
weight = target > threshold
weight_value = 10.
weight = weight.float() * weight_value

# weighted loss, reduced to a scalar so it can be backpropagated
loss = F.mse_loss(output, target, reduction='none') * weight
loss = loss.mean()
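
If you want to fold this into your training_step, here is one possible sketch (the +1. baseline weight is my own addition, so that background pixels still contribute to the loss instead of being ignored entirely):

    def training_step(self, batch, batch_idx):
        x, mask = batch
        mask_pred = torch.sigmoid(self(x))

        # upweight pixels near the Gaussian peaks, keep weight 1 elsewhere
        weight = (mask > 0.2).float() * 10. + 1.
        loss = (F.mse_loss(mask_pred, mask, reduction='none') * weight).mean()

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss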