Advanced Indexing (Medoid)

I have a tensor of shape torch.Size([400, 32, 32])

Each 32x32 patch contains only ones and zeros.

I want to return a geometric center / medoid of foreground pixels (1s).

The shape of the return value would be [400, 2].

Example:

For this [4, 4] patch

0 1 1 1
0 1 1 1
0 0 1 1
0 0 1 1

the returned coordinate would be [1, 2].

How to do that without a loop?

Thank you.

Hi Assefa!

If the mean location of the foreground pixels would be adequate for your
use case, you could do something like:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> t = torch.tensor([[0, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 1, 1]])
>>> t.nonzero().mean(dim=0, dtype=torch.float).round().long()
tensor([1, 2])
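As an aside, if by "medoid" you need an actual foreground pixel rather than the (possibly off-grid) rounded mean, one plausible sketch is to pick the foreground pixel closest to the mean location. This is just one reading of "medoid", not necessarily the exact computation you have in mind:

```python
import torch

# example patch from above
t = torch.tensor([[0, 1, 1, 1],
                  [0, 1, 1, 1],
                  [0, 0, 1, 1],
                  [0, 0, 1, 1]])

fg = t.nonzero().float()            # coordinates of foreground pixels
mean = fg.mean(dim=0)               # mean location (may lie between pixels)
# medoid-like choice: foreground pixel nearest the mean location
medoid = fg[(fg - mean).norm(dim=1).argmin()].long()
print(medoid)                       # tensor([1, 2])
```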

If the above doesn’t work for your use case, could you illustrate the specific
computation you want with a self-contained, runnable script that uses loops,
as necessary?

Best.

K. Frank

Hello.

Your solution works only for 2D data.

My input’s shape is [N, 32, 32].

The solution with a loop would be something like,

t = torch.rand((400, 32, 32)) > 0.5
mid = torch.stack([p.nonzero().mean(dim=0, dtype=torch.float).round().long() for p in t])
mid.shape

Thank you.

Hi Assefa!

To do this on a “batch” basis, probably the simplest approach is to
generate a tensor that holds the indices of all the pixels in your 32x32
slices and then compute the foreground-pixel weighted average of those
indices:

>>> import torch
>>> print(torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed(2022)
>>>
>>> t = torch.multinomial(torch.tensor([249.0, 1.0]), 10 * 32 * 32, True).reshape(10, 32, 32)
>>> t.shape
torch.Size([10, 32, 32])
>>>
>>> mid = torch.stack([p.nonzero().mean(dim=0, dtype=torch.float).round().long() for p in t])
>>>
>>> indices = torch.stack((torch.arange(32).expand(32, 32).T, torch.arange(32).expand(32, 32)), dim=0)
>>> midB = ((t.unsqueeze(1) * indices).sum(dim=(2, 3)) / t.sum(dim=(1, 2)).unsqueeze(1)).round().long()
>>>
>>> midB
tensor([[ 6, 18],
        [13, 17],
        [20, 18],
        [ 3,  6],
        [15, 15],
        [17, 11],
        [12, 24],
        [ 9, 20],
        [16, 17],
        [17, 17]])
>>>
>>> torch.equal (mid, midB)
True
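If you find the expand/transpose construction of the index tensor hard to read, the same [2, 32, 32] tensor can also be built with torch.meshgrid (indexing='ij' gives row-major coordinates; this keyword is available in recent versions such as the 1.12.0 used here):

```python
import torch

# rows[i][j] == i and cols[i][j] == j
rows, cols = torch.meshgrid(torch.arange(32), torch.arange(32), indexing="ij")
indices = torch.stack((rows, cols), dim=0)

# matches the expand/transpose construction used above
indicesB = torch.stack((torch.arange(32).expand(32, 32).T,
                        torch.arange(32).expand(32, 32)), dim=0)
print(torch.equal(indices, indicesB))   # True
```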

Note, if a 32x32 slice happens to contain no foreground pixels, you will get
nan (before converting to long()) for the mean foreground pixel location
for that slice. That’s probably as reasonable an “undefined” or “special-case”
value as any.
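If you would rather avoid the nan, one option is to clamp the foreground count to a minimum of 1 before dividing; empty slices then come out as [0, 0], which is an arbitrary sentinel choice. A small sketch of this variant:

```python
import torch

t = torch.zeros(3, 32, 32, dtype=torch.long)
t[0, 5, 7] = 1              # slice 0: a single foreground pixel
t[1, 10:13, 20:23] = 1      # slice 1: a 3x3 foreground block
                            # slice 2: no foreground pixels at all

rows = torch.arange(32).view(32, 1).expand(32, 32)
cols = torch.arange(32).view(1, 32).expand(32, 32)
indices = torch.stack((rows, cols), dim=0)        # shape [2, 32, 32]

counts = t.sum(dim=(1, 2)).clamp(min=1)           # avoid division by zero
mid = ((t.unsqueeze(1) * indices).sum(dim=(2, 3)) / counts.unsqueeze(1)).round().long()
print(mid)   # tensor([[ 5,  7], [11, 21], [ 0,  0]])
```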

Best.

K. Frank