Finding groups of pixels

I’m trying to find groups of pixels within a mask as shown here:

The circles are initially generated with the following code:

# Blank 3-channel 200x200 canvas, converted to a PIL image for drawing.
t = torch.zeros(3, 200, 200)
im = to_pil_image(t)
draw = ImageDraw.Draw(im)

# Draw three blue ellipses, each in a 20x20 bounding box.
# NOTE(review): `rint` is not defined in this snippet; presumably it returns
# a random integer below its argument, keeping the box inside the canvas.
for _ in range(3):
    s = rint(180)
    e = rint(180)
    box = (s, e, s + 20, e + 20)
    draw.ellipse(box, fill='blue')

# Convert the drawn image back into a tensor.
t = to_tensor(im)

Here is how I’m getting those groups of pixels:

def get_all_pixels_near(group, pixels, distance=2):
    """Grow ``group`` by every pixel in ``pixels`` close to one of its members.

    Args:
        group: ``(G, 2)`` tensor of pixel coordinates already in the group.
        pixels: ``(P, 2)`` tensor of candidate pixel coordinates.
        distance: A candidate is added when its Euclidean distance to at
            least one group member is strictly less than this value.

    Returns:
        A 2-D tensor holding the de-duplicated rows of ``group`` together
        with the selected candidates, in the order produced by
        ``Tensor.unique(dim=0)``.
    """
    # Pairwise offsets via broadcasting: (G, 1, 2) - (1, P, 2) -> (G, P, 2).
    # Unlike repeat()/view(), this also works when ``pixels`` is empty.
    offsets = group[:, None, :] - pixels[None, :, :]
    ds = offsets.pow(2).sum(2).float().sqrt()

    # A candidate qualifies if it is near *any* group member. Reducing over
    # the group axis replaces the per-member loop of repeated torch.cat calls.
    is_near = (ds < distance).any(dim=0)
    return torch.cat([pixels[is_near], group]).unique(dim=0)

def find_clusters(mask, masks=None, distance=2):
    """Split ``mask`` into one mask per cluster of nearby non-zero pixels.

    Two pixels belong to the same cluster when they are linked by a chain of
    non-zero pixels in which each step is shorter than ``distance``.

    Args:
        mask: 3-D tensor indexed as ``[channel, row, col]``. A pixel counts
            as set when any channel is non-zero at that position.
        masks: Optional list of cluster masks found earlier; the new
            clusters are appended to a copy of it.
        distance: Neighbourhood radius passed to ``get_all_pixels_near`` on
            every expansion step.

    Returns:
        A list of tensors shaped like ``mask``, one per cluster, holding 1 in
        channel 0 at the cluster's pixels and 0 everywhere else. The input
        ``mask`` is not modified.
    """
    # A fresh list per call: never hand the shared default back to the caller.
    found = [] if masks is None else list(masks)
    remaining = mask.clone()

    # Iterate instead of recursing once per cluster, so the number of
    # clusters is not limited by the interpreter's recursion depth.
    while True:
        ungrouped_pixels = remaining.nonzero()[:, 1:]
        if len(ungrouped_pixels) < 1:
            return found

        # Seed with a single pixel and expand until the cluster stops
        # growing. The size is counted in pixels (rows), and the caller's
        # ``distance`` is used on every expansion.
        near = ungrouped_pixels[:1]
        prev_len = 0
        while prev_len != len(near):
            prev_len = len(near)
            near = get_all_pixels_near(near, ungrouped_pixels, distance)

        rows, cols = near[:, 0], near[:, 1]

        # Build the cluster mask directly rather than through a 0.5 sentinel,
        # which is lost on integer masks and collides with real 0.5 values.
        cluster = torch.zeros_like(mask)
        cluster[0, rows, cols] = 1
        found.append(cluster)

        # Clear the grouped pixels in every channel so each pass strictly
        # shrinks the set of remaining pixels and the loop terminates.
        remaining[:, rows, cols] = 0

Is there a more efficient way to do this?