Largest connected component in loss function

I would like to add a loss function that only takes into account the largest connected component of my network's output (a segmentation mask). My idea is that this will lead the network to be less eager to disconnect small objects.

Is this possible using only torch operations?
I already tried detaching the output and using numpy-based methods (`skimage.measure.label`), but numpy operations are not tracked by autograd.
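For reference, here is a minimal sketch of one possible approach. It assumes a single-channel (H, W) logit map and uses PyTorch only. The labeling step does not itself need to be differentiable. The largest-component mask can be computed under `torch.no_grad()` and used as a constant multiplier, so gradients still reach the network through the probabilities inside the mask. Labels are spread with a 3x3 max-pool (8-connectivity) until they stop changing. The names `largest_component_mask` and `largest_component_loss`, the 0.5 threshold, and the choice of BCE are all hypothetical illustrations, not an established method.

```python
import torch
import torch.nn.functional as F


def largest_component_mask(binary: torch.Tensor, max_iters: int = 10_000) -> torch.Tensor:
    """Return a bool mask of the largest 8-connected component of a (H, W) bool tensor."""
    if not binary.any():
        return torch.zeros_like(binary)
    h, w = binary.shape
    fg = binary.to(torch.float64)
    # Give every foreground pixel a unique positive label; background stays 0.
    labels = torch.arange(1, h * w + 1, dtype=torch.float64, device=binary.device).view(h, w) * fg
    for _ in range(max_iters):
        # A 3x3 max-pool copies the largest neighbouring label onto each pixel;
        # multiplying by fg keeps background at 0, so labels only spread through foreground.
        new = F.max_pool2d(labels[None, None], kernel_size=3, stride=1, padding=1)[0, 0] * fg
        if torch.equal(new, labels):
            break  # every component now carries a single label
        labels = new
    counts = torch.bincount(labels.long().flatten())
    counts[0] = 0  # ignore the background label
    return labels == counts.argmax()


def largest_component_loss(logits: torch.Tensor, target: torch.Tensor,
                           threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical loss: BCE between the prediction restricted to its largest component and the target."""
    prob = torch.sigmoid(logits)
    with torch.no_grad():
        # Non-differentiable step: the mask is treated as a constant.
        mask = largest_component_mask(prob > threshold).to(prob.dtype)
    # Gradients flow through prob inside the mask; pixels outside it receive zero gradient.
    return F.binary_cross_entropy(prob * mask, target)
```

The max-pool loop needs roughly as many iterations as the longest path inside a component, so it may be slow for large or winding shapes. For a batch, the mask would be computed per image.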

Any suggestions? Thanks