Differentiable binary mask centroid computation

Hello, I am wondering whether there is a way to compute the centroid of a binary blob in an image that is differentiable. Essentially, I am looking for a differentiable twin of the loss function shown here:

import torch
import torch.nn as nn


class CentroidLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, rendered_silhouette, original_silhouette):
        # (row, col) coordinates of the nonzero pixels, shape (N, 2).
        # torch.nonzero returns integer indices, so no gradient flows back
        # to rendered_silhouette through this path.
        rendered_sil_indices = torch.nonzero(rendered_silhouette)
        rendered_sil_centroid = rendered_sil_indices.float().mean(dim=0)

        original_sil_indices = torch.nonzero(original_silhouette)
        original_sil_centroid = original_sil_indices.float().mean(dim=0)

        # Squared distance between the centroids (row term and column term),
        # normalized by the image area.
        area = rendered_silhouette.shape[0] * rendered_silhouette.shape[1]
        centroid_dist = ((rendered_sil_centroid[0] - original_sil_centroid[0]) ** 2
                         + (rendered_sil_centroid[1] - original_sil_centroid[1]) ** 2) / area

        return centroid_dist

“rendered_silhouette” and “original_silhouette” are binary images, each containing a blob.

Hi Shubham!

The short answer is to compute the center of mass (the “weighted
centroid”) of rendered_silhouette before thresholding.

Let me speculate as to what your use case might be:

Your original_silhouette plays the role of some kind of “ground-truth”
target. Perhaps your network is trying to reconstruct it somehow.

rendered_silhouette is generated by a network and you want to
train the weights of the network so that rendered_silhouette
approximates original_silhouette. Therefore you want a differentiable
loss function so that you can use backpropagation for training.

In all probability your network generates an image that is not binary.
That is, the value of each pixel in the generated image is the probability
(or perhaps the logit) for that pixel being 1 (rather than 0) in the final
binary image. You then threshold this “probability” image to get your
final rendered_silhouette binary image.

This thresholding is not (usefully) differentiable, so you can’t
backpropagate through it.
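A quick way to see this in PyTorch, sketched with a random tensor as a hypothetical stand-in for the network's per-pixel logits: the comparison produces a boolean tensor that is cut off from the autograd graph.

```python
import torch

# Hypothetical stand-in for the network's per-pixel logits (H x W).
logits = torch.randn(8, 8, requires_grad=True)
probs = torch.sigmoid(logits)  # differentiable "probability" image

# Thresholding yields a boolean mask with no grad_fn, so autograd
# has nothing to propagate back through.
binary = probs > 0.5

print(probs.requires_grad)   # True
print(binary.requires_grad)  # False
print(binary.dtype)          # torch.bool
```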

What you can do instead is calculate the center of mass of the (blob
in the) probability image. This will be a properly-differentiable version
of the (non-differentiable) centroid of the thresholded image. Use it
to compute centroid_dist returned by your CentroidLoss function.
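A minimal sketch of that idea, assuming a single-channel H x W image with no batch dimension. The names `soft_centroid` and `SoftCentroidLoss` are hypothetical, and the small `eps` guarding against an all-zero image is my own addition. The loss keeps the same form as `CentroidLoss` (squared centroid distance divided by image area), but the centroid is the pixel-value-weighted mean of the pixel coordinates.

```python
import torch
import torch.nn as nn


def soft_centroid(prob_image, eps=1e-8):
    """Pixel-value-weighted centroid (center of mass) of a 2D image, as (row, col)."""
    h, w = prob_image.shape
    # Pixel coordinates are fixed constants; they need no gradient.
    rows = torch.arange(h, dtype=prob_image.dtype, device=prob_image.device)
    cols = torch.arange(w, dtype=prob_image.dtype, device=prob_image.device)
    total = prob_image.sum() + eps  # eps avoids division by zero for an empty image
    row_c = (prob_image.sum(dim=1) * rows).sum() / total  # row marginal x row index
    col_c = (prob_image.sum(dim=0) * cols).sum() / total  # column marginal x column index
    return torch.stack([row_c, col_c])


class SoftCentroidLoss(nn.Module):
    """Squared centroid distance normalized by image area, differentiable in rendered_probs."""

    def forward(self, rendered_probs, original_silhouette):
        c_rendered = soft_centroid(rendered_probs)
        c_original = soft_centroid(original_silhouette.to(rendered_probs.dtype))
        h, w = rendered_probs.shape
        return ((c_rendered - c_original) ** 2).sum() / (h * w)
```

Here `rendered_probs` would be the sigmoid output of your network before thresholding. For a binary image, `soft_centroid` agrees (up to `eps`) with the `torch.nonzero(...).float().mean(dim=0)` centroid from your original loss.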

You can now use this differentiable loss to backpropagate and train
the weights in your network.
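To check that these gradients actually move the predicted centroid, here is a toy training loop. It assumes a learnable per-pixel logit image as a hypothetical stand-in for the network, and a made-up square blob as the target silhouette. The center of mass is computed with `torch.meshgrid` coordinate grids, which is equivalent to the marginal-sum formulation.

```python
import torch

H, W = 16, 16
# Hypothetical target: a square binary blob standing in for original_silhouette.
target = torch.zeros(H, W)
target[10:14, 10:14] = 1.0

# Hypothetical stand-in for the network: one learnable logit per pixel.
logits = torch.zeros(H, W, requires_grad=True)
optimizer = torch.optim.Adam([logits], lr=0.1)

# Fixed pixel-coordinate grids: constants, never trained.
rows, cols = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                            torch.arange(W, dtype=torch.float32), indexing="ij")


def center_of_mass(img):
    total = img.sum() + 1e-8
    return torch.stack([(img * rows).sum() / total, (img * cols).sum() / total])


target_c = center_of_mass(target)
losses = []
for step in range(200):
    probs = torch.sigmoid(logits)  # pre-threshold "probability" image
    loss = ((center_of_mass(probs) - target_c) ** 2).sum() / (H * W)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

The loss starts at (4² + 4²) / 256 = 0.125, because a uniform 0.5 image has its centroid at (7.5, 7.5) while the blob's centroid is at (11.5, 11.5), and it shrinks as the logits in the lower-right corner grow.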

Best.

K. Frank

Hi Frank,

Thanks for responding. Yes, your speculation is correct. How does one go about computing the center of mass, though?

In essence, one needs to multiply the indices by the pixel values and then take the mean of it all. But I can’t think of a differentiable way of obtaining the indices.

Thanks,
Shubham

Hi Shubham!

Yes, this is correct. More precisely, you need to take the
pixel-value-weighted mean of the pixel locations (indices)
to get the “center-of-mass” location.
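Written out, with p_{ij} the (pre-thresholding) value of the pixel in row i and column j, the center-of-mass location is:

```latex
\bar{r} = \frac{\sum_{i,j} i \, p_{ij}}{\sum_{i,j} p_{ij}},
\qquad
\bar{c} = \frac{\sum_{i,j} j \, p_{ij}}{\sum_{i,j} p_{ij}}
```

For a binary image (every p_{ij} is 0 or 1) this reduces to the plain mean of the indices of the nonzero pixels, which is what `torch.nonzero(...).float().mean(dim=0)` computes.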

You don’t need to get “differentiable” indices. What matters is that
your (pre-thresholding) pixel values are differentiable (with respect
to your network weights), and presumably they are. Your indices
(pixel locations) are fixed constants. They do not need to be – and
are not – differentiable.
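A small check of this point, sketched with a random image as a hypothetical stand-in for your pre-thresholding output. The index tensors never require gradients, yet autograd still produces the gradient with respect to the pixel values. Differentiating the row centroid gives ∂r̄/∂p_ij = (i − r̄) / Σp, in which the constant index i appears only as a coefficient.

```python
import torch

# Hypothetical pre-threshold "probability" image; only its values need gradients.
probs = torch.rand(6, 5, dtype=torch.float64, requires_grad=True)

# Pixel locations: plain constant tensors, never part of the autograd graph.
rows = torch.arange(6, dtype=torch.float64).unsqueeze(1)  # shape (6, 1)

total = probs.sum()
row_c = (probs * rows).sum() / total  # column centroid is analogous

row_c.backward()

# Analytic gradient: d(row_c)/d(p_ij) = (i - row_c) / total.
expected = ((rows - row_c.detach()) / total.detach()).expand(6, 5)

print(rows.requires_grad)                     # False
print(torch.allclose(probs.grad, expected))   # True
```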

Best.

K. Frank