I’m trying to find groups of pixels within a mask as shown here:
Circles are initially generated with the code:
t = torch.zeros(3,200,200)
im = to_pil_image(t)
draw = ImageDraw.Draw(im)
for _ in range(3):
s, e = rint(180), rint(180)
draw.ellipse((s, e, s+20, e+20), fill = 'blue')
t = to_tensor(im)
Here is how I’m getting those groups of pixels:
def get_all_pixels_near(group, pixels, distance=2):
stacked = group.repeat(1, pixels.shape[0]).view(-1, pixels.shape[0], 2)
ds = stacked.sub(pixels).pow(2).sum(2).float().sqrt()
for i in ds:
group = torch.cat([pixels[i<distance], group])
return group.unique(dim=0)
def find_clusters(mask, masks=[], distance=2):
mask = mask.clone()
ungrouped_pixels = mask.nonzero()[:, 1:]
if len(ungrouped_pixels) < 1:
return masks
p = ungrouped_pixels[0]
prev_len = len(p)
near = get_all_pixels_near(p[None], ungrouped_pixels, distance)
while prev_len != len(near):
prev_len = len(near)
near = get_all_pixels_near(near, ungrouped_pixels)
m2 = mask.clone()
for i in near:
m2[0,i[0],i[1]] = .5
mask[0,i[0],i[1]] = 0
m2[m2!=.5] = 0
m2[m2==.5] = 1
return find_clusters(mask, masks+[m2])
Is there is a more efficient way to do this?