Hi all, I have a 5-dimensional tensor of masks and I want to create a 4-D tensor by masking a 4-D input with each of these masks. So far I could only come up with a for loop over the unbound masks, but I'd appreciate it if someone could suggest a way to eliminate the loop, since the number of masks is large (hundreds). Here is an example:

``````python
>>> import torch

>>> _ = torch.manual_seed(16)
>>> inp = torch.randint(1, 5, (3, 4), dtype=torch.float)
>>> outp = torch.zeros(inp.shape)
>>> for m in torch.unbind(masks, dim=0):
...     outp[m] = inp[m]

>>> inp
tensor([[2., 3., 2., 2.],
        [2., 4., 1., 2.],
        [4., 1., 2., 1.]])

>>> masks
tensor([[[False,  True,  True,  True],
         [ True, False, False,  True],
         [ True, False,  True, False]],

        [[False,  True,  True,  True],
         [ True,  True, False, False],
         [False,  True,  True,  True]]])

>>> outp
tensor([[0., 3., 2., 2.],
        [2., 4., 0., 2.],
        [4., 1., 2., 1.]])
``````

I don't think there is a proper way to avoid the for loop, since you are sequentially overwriting `outp` in:

``````python
outp[m] = inp[m]
``````

I.e. `outp` and `inp` have a shape of `[3, 4]`, while you are creating `num_masks` masks with the same shape.
In the first iteration the first mask would thus manipulate `outp`; in the second iteration this already modified `outp` tensor would be manipulated again, so you have a sequential dependency.

One potential way would be to “sum” all masks into a single one and perform this operation only once:

``````python
mask = masks.sum(dim=0).bool()  # union of all masks
out = torch.zeros(inp.shape)
out[mask] = inp[mask]
out == outp
``````

but it depends on your use case and whether that's really what you are trying to achieve.
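A quick check of the single-mask idea on the tensors printed in the first post (typed in by hand here rather than regenerated from the seed). Since every iteration copies the same `inp` values into the positions selected by `m`, the result only depends on which positions are selected by at least one mask. `masks.any(dim=0)` is used as an equivalent of `masks.sum(dim=0).bool()`, and the `torch.where` line is just an alternative spelling without in-place indexing.

```python
import torch

# Tensors copied from the printout in the first post
inp = torch.tensor([[2., 3., 2., 2.],
                    [2., 4., 1., 2.],
                    [4., 1., 2., 1.]])
masks = torch.tensor([[[False, True, True, True],
                       [True, False, False, True],
                       [True, False, True, False]],
                      [[False, True, True, True],
                       [True, True, False, False],
                       [False, True, True, True]]])

# Loop version from the question
loop_out = torch.zeros(inp.shape)
for m in torch.unbind(masks, dim=0):
    loop_out[m] = inp[m]

# Single mask: a position is copied if any mask selects it
union = masks.any(dim=0)  # same as masks.sum(dim=0).bool()
single_out = torch.zeros(inp.shape)
single_out[union] = inp[union]

# Same result without in-place indexing
where_out = torch.where(union, inp, torch.zeros_like(inp))

print(torch.equal(loop_out, single_out), torch.equal(loop_out, where_out))  # True True
```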

Hi @ptrblck,

Summing the masks would not be an appropriate approach to this problem. Each mask represents a connected component (CC) of arbitrary shape in the image. I use them for a custom max pooling operation: for each mask, all pixels of that CC are set to the maximum value within that CC.

The operation does not have to be sequential, but I could not think of a better way than the for loop, so non-sequential approaches are especially welcome. Here is an example; I can post the whole code if needed, but I'm basically stuck at the for loop mentioned in the first post.

``````python
>>> inp
tensor([[[[11., 11., 10., 12.],
          [10., 11., 11., 10.],
          [11., 12., 12., 12.],

>>> CCs
tensor([[0., 0., 0., 1.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

>>> masks
tensor([[[[[ True,  True,  True, False],
           [False, False, False, False],
           [False, False, False, False],
           [False, False, False, False]]]],

        [[[[False, False, False,  True],
           [False, False, False, False],
           [False, False, False, False],
           [False, False, False, False]]]],

        [[[[False, False, False, False],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]]]]])

>>> output
tensor([[[[11., 11., 11., 12.],
          [12., 12., 12., 12.],
          [12., 12., 12., 12.],
          [12., 12., 12., 12.]]]])
``````

Unfortunately, I don’t understand how your `output` tensor is created given the `inp` and `masks` (e.g. where is the `11.` coming from in `output[0, 0, 0, 2]`).

In any case, your manual loop approach and the “summed” mask approach would still yield the same result, as you are just applying the masks sequentially and could thus also create a single mask:

``````python
import torch

# setup
inp = torch.tensor([[[[11., 11., 10., 12.],
[10., 11., 11., 10.],
[11., 12., 12., 12.],

masks = torch.tensor([[[[[ True,  True,  True, False],
[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]]],

[[[[False, False, False,  True],
[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]]],

[[[[False, False, False, False],
[ True,  True,  True,  True],
[ True,  True,  True,  True],
[ True,  True,  True,  True]]]]])

# reference
reference_output = torch.tensor([[[[11., 11., 11., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.]]]])

# loop
loop_output = torch.zeros(inp.shape)
for m in torch.unbind(masks, dim=0):
    loop_output[m] = inp[m]

print((loop_output == reference_output).all())
# tensor(False)

# summed
summed_mask = masks.sum(dim=0).bool()
summed_output = torch.zeros(inp.shape)
summed_output[summed_mask] = inp[summed_mask]
print((summed_output == reference_output).all())
# tensor(False)
print((summed_output == loop_output).all())
# tensor(True)
``````


Sorry for the poor explanation. Each value in `CCs` indicates which connected component that pixel belongs to. So in the above example, 3 pixels belong to the 0th CC, 1 pixel belongs to the 1st CC, and the rest belong to the 2nd CC.

For the custom pooling, all pixels of a CC are set to the maximum value of that CC in the input image. These are `11.` for the 0th CC, `12.` for the 1st CC and again `12.` for the 2nd CC.
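To pin down the intended semantics, here is a small loop-based reference (a sketch; `cc_max_pool_loop` is a hypothetical helper name). It builds one mask per component id with `labels == c` and writes that component's maximum to all of its pixels. It uses the `CCs` labels and the first three rows of `inp` from the example above; the last row of `inp` is illustrative. A slow reference like this is handy for checking any vectorized version.

```python
import torch


def cc_max_pool_loop(inp: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Reference version: one boolean mask per component id (the slow loop)."""
    out = torch.empty_like(inp)  # every pixel belongs to some component, so all get overwritten
    for c in torch.unique(labels):
        m = labels == c          # pixels of component c
        out[m] = inp[m].max()    # broadcast that component's max to all of its pixels
    return out


inp = torch.tensor([[11., 11., 10., 12.],
                    [10., 11., 11., 10.],
                    [11., 12., 12., 12.],
                    [ 9., 10., 11.,  9.]])  # last row is illustrative
CCs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])
out = cc_max_pool_loop(inp, CCs)
print(out)
```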

The problem with the summed mask approach is that every pixel belongs to some connected component. Therefore, `masks.sum(dim=0).bool()` results in a `mask` filled only with `True` values, and applying it to `inp` just returns `inp` again. Here is the full operation; the for loop in `forward` is the bottleneck:

``````python
import torch

@staticmethod
def forward(ctx, inp, CCs):
    maxified = torch.ones(inp.shape, device=inp.device)
    cs = torch.unique(CCs)
    for m in torch.unbind(masks, dim=0):  # bottleneck
        m_inp = inp * m
        dim_maxes = torch.amax(m_inp.view(inp.shape[0], inp.shape[1], -1).squeeze(), dim=-1)
        num_dims = m.dim() - dim_maxes.dim()
        new_shape = dim_maxes.shape + (1,) * num_dims if m.shape[0] > 1 else (1,) + dim_maxes.shape + (1,) * (num_dims - 1)
        maxified[m] = dim_maxes.view(new_shape).expand_as(m)[m]
    return maxified

if __name__ == "__main__":

inp = torch.tensor([[[[11., 11., 10., 12.],
[10., 11., 11., 10.],
[11., 12., 12., 12.],

CCs = torch.tensor([[0., 0., 0., 1.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])

reference_output = torch.tensor([[[[11., 11., 11., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.]]]])

print((outp == reference_output).all())
# tensor(True)
``````
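One loop-free option that needs only core PyTorch is to build all component masks at once by broadcasting the label map against `torch.arange(num_CCs)`, take a masked `amax` per component, and `gather` the maxima back onto the pixels. This is a sketch under some assumptions: `inp` is `[B, C, H, W]`, the labels are an int64 tensor of shape `[B, H, W]` with ids numbered from 0 within each image, and `cc_max_pool_dense` is a hypothetical name. It materializes a `[B, C, num_CCs, H, W]` tensor, so with hundreds of components it can be memory-hungry.

```python
import torch


def cc_max_pool_dense(inp: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Set every pixel to the max of its connected component, without a Python loop.

    inp:    [B, C, H, W] float tensor
    labels: [B, H, W] int64 component ids, numbered from 0 in each image
    """
    num_cc = int(labels.max()) + 1
    ids = torch.arange(num_cc, device=labels.device)
    # One boolean mask per component id, all at once: [B, num_cc, H, W]
    masks = labels.unsqueeze(1) == ids.view(1, num_cc, 1, 1)
    # Keep inp inside each mask and -inf elsewhere: [B, C, num_cc, H, W]
    neg_inf = torch.tensor(float("-inf"), dtype=inp.dtype, device=inp.device)
    masked = torch.where(masks.unsqueeze(1), inp.unsqueeze(2), neg_inf)
    cc_max = masked.amax(dim=(-2, -1))  # [B, C, num_cc]
    # Each pixel looks up its component's max along the component dimension
    index = labels.unsqueeze(1).expand(-1, inp.shape[1], -1, -1).flatten(2)  # [B, C, H*W]
    return cc_max.gather(2, index).view_as(inp)


inp = torch.tensor([[[[11., 11., 10., 12.],
                      [10., 11., 11., 10.],
                      [11., 12., 12., 12.],
                      [ 9., 10., 11.,  9.]]]])  # [1, 1, 4, 4]; last row is illustrative
CCs = torch.tensor([[[0, 0, 0, 1],
                     [2, 2, 2, 2],
                     [2, 2, 2, 2],
                     [2, 2, 2, 2]]])  # [1, 4, 4]: batch dim added to the labels above
out = cc_max_pool_dense(inp, CCs)
print(out)
```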

If the components are disjoint and you get the numbers in a single integer tensor, you could first use `scatter_max` from the PyTorch scatter library and then indexing with the mask to get the values back into the index.
There are faster ways to do this, but at least the one I implemented isn’t open source.

Best regards

Thomas
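A sketch of the "scatter, then index back" idea for a single image, assuming PyTorch 1.12 or newer: instead of `torch_scatter.scatter_max`, it swaps in the built-in `Tensor.scatter_reduce` with `reduce="amax"`, starting from `-inf` so that only real pixel values count. The "index back" step is then plain advanced indexing, `cc_max[labels]`, which lets every pixel look up the maximum of its own component. `cc_max_pool_scatter` is a hypothetical name, and the last row of `inp` is illustrative.

```python
import torch


def cc_max_pool_scatter(inp: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """inp: [H, W] float; labels: [H, W] int64 ids of disjoint components, numbered from 0."""
    num_cc = int(labels.max()) + 1
    # Per-component maximum in a single call (no loop over components)
    cc_max = torch.full((num_cc,), float("-inf"), dtype=inp.dtype, device=inp.device)
    cc_max = cc_max.scatter_reduce(0, labels.reshape(-1), inp.reshape(-1), reduce="amax")
    # Index back: each pixel reads the max of the component it belongs to
    return cc_max[labels]


inp = torch.tensor([[11., 11., 10., 12.],
                    [10., 11., 11., 10.],
                    [11., 12., 12., 12.],
                    [ 9., 10., 11.,  9.]])  # last row is illustrative
CCs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])
out = cc_max_pool_scatter(inp, CCs)
print(out)
```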

Hi @tom, yes, the connected components are disjoint.

> and then indexing with the mask to get the values back into the index.

Could you elaborate on this part a bit more, please?

Update: This is what I could come up with from your approach, but it still has the same bottleneck `for` loop:

``````python
import torch
from torch_scatter import scatter_max

inp = torch.tensor([[[[11., 11., 10., 12.],
[10., 11., 11., 10.],
[11., 12., 12., 12.],

CCs = torch.tensor([[0, 0, 0, 1],
[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2]])

reference_output = torch.tensor([[[[11., 11., 11., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.]]]])

indexes = torch.unique(CCs)
out = inp.new_zeros(indexes.shape)
out, argmax = scatter_max(inp.view(-1), CCs.view(-1), out=out)
maxified = inp.new_ones(inp.shape)

for i in range(len(indexes)):

print((maxified == reference_output).all())
# tensor(True)
``````
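The remaining loop over `indexes` can go away the same way: once `out` holds one maximum per component id, `out[CCs]` already gives every pixel its component's maximum (reshaped to `inp.shape` if needed). For batched inputs, with `inp` of shape `[B, C, H, W]` and labels of shape `[B, H, W]`, the same idea can be written with the built-in `scatter_reduce` and `gather` along the flattened pixel dimension. Again a sketch, assuming PyTorch 1.12 or newer and ids numbered from 0 within each image; `cc_max_pool_batched` is a hypothetical name.

```python
import torch


def cc_max_pool_batched(inp: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """inp: [B, C, H, W] float; labels: [B, H, W] int64, ids numbered from 0 in each image."""
    B, C, H, W = inp.shape
    num_cc = int(labels.max()) + 1  # upper bound on the number of components per image
    src = inp.reshape(B, C, H * W)
    # Same label map for every channel of an image: [B, C, H*W]
    index = labels.reshape(B, 1, H * W).expand(B, C, H * W)
    # Per-(image, channel, component) maximum; ids absent from an image stay at -inf
    cc_max = inp.new_full((B, C, num_cc), float("-inf"))
    cc_max = cc_max.scatter_reduce(2, index, src, reduce="amax")
    # Map the maxima back onto the pixels
    return cc_max.gather(2, index).reshape(B, C, H, W)


torch.manual_seed(0)
x = torch.rand(2, 3, 4, 4)
ids = torch.randint(0, 4, (2, 4, 4))
out = cc_max_pool_batched(x, ids)
print(out.shape)  # torch.Size([2, 3, 4, 4])
```

Because the ids are only ever used as positions along the last dimension, images with fewer components than `num_cc` simply leave some `-inf` slots that are never gathered.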

I found a way to do it without a `for` loop, combining the approaches of @tom and @ptrblck, with some support from friends:

``````python
import torch
from torch_scatter import scatter_max

@staticmethod
def forward(ctx, inp, CCs):
    indexes = torch.unique(CCs)  # These are also the unique connected component ids.
    cc_maxes = inp.new_zeros(inp.shape[:2] + indexes.shape)  # [bs, ch, num_CCs] tensor holding each CC's max value.
    cc_maxes, _ = scatter_max(src=inp.view(inp.shape[:2] + (-1,)),
                              index=CCs.unsqueeze(1).expand_as(inp).view(inp.shape[:2] + (-1,)),
                              out=cc_maxes)

if __name__ == "__main__":
inp = torch.tensor([[[[ 8.,  5., 10., 12.],
[11.,  7., 10., 12.],
[10.,  8., 12., 10.],
[ 8., 11., 10.,  6.]],

[[ 7., 10.,  6.,  5.],
[ 7.,  7.,  5.,  7.],
[12., 10.,  5., 11.],
[ 7.,  6.,  7., 12.]],

[[ 9.,  7.,  5.,  9.],
[ 7.,  9., 12.,  9.],
[10., 11., 11.,  8.],
[10., 10.,  6.,  8.]]],

[[[10., 11., 11., 10.],
[ 8., 10., 12.,  7.],
[ 8.,  5.,  7.,  8.],
[ 6.,  9.,  6.,  8.]],

[[ 9., 10.,  7., 11.],
[ 8.,  9.,  7., 11.],
[ 8., 10.,  6., 11.],
[10.,  7.,  7.,  6.]],

[[12., 10., 10., 11.],
[12., 11., 10.,  8.],
[12.,  8., 10.,  9.],

cidx = torch.tensor([[[0, 1, 1, 1],
[2, 3, 3, 3],
[2, 3, 3, 3],
[2, 3, 3, 3]],

[[0, 1, 1, 2],
[0, 1, 1, 2],
[3, 4, 4, 5],
[3, 4, 4, 5]]])

reference_output = torch.tensor([[[[ 8., 12., 12., 12.],
[11., 12., 12., 12.],
[11., 12., 12., 12.],
[11., 12., 12., 12.]],

[[ 7., 10., 10., 10.],
[12., 12., 12., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.]],

[[ 9.,  9.,  9.,  9.],
[10., 12., 12., 12.],
[10., 12., 12., 12.],
[10., 12., 12., 12.]]],

[[[10., 12., 12., 10.],
[10., 12., 12., 10.],
[ 8.,  9.,  9.,  8.],
[ 8.,  9.,  9.,  8.]],

[[ 9., 10., 10., 11.],
[ 9., 10., 10., 11.],
[10., 10., 10., 11.],
[10., 10., 10., 11.]],

[[12., 11., 11., 11.],
[12., 11., 11., 11.],
[12., 10., 10., 10.],
[12., 10., 10., 10.]]]])