Masking tensor with multiple masks without for loop

Hi all, I have a 5-dimensional tensor of masks and I want to create a 4-D tensor by masking a 4-D input with each of these masks. So far I could only come up with a for loop over the unbound masks, but I’d appreciate it if someone could suggest a way to eliminate the for loop, since the number of masks is high (hundreds). Here is an example:

>>> import torch

>>> _ = torch.manual_seed(16)
>>> num_masks = 2
>>> inp = torch.randint(1, 5, (3, 4), dtype=torch.float)
>>> outp = torch.zeros(inp.shape)
>>> masks = torch.randn((num_masks,) + inp.shape) > 0
>>> for m in torch.unbind(masks, dim=0):
...     outp[m] = inp[m]


>>> inp
tensor([[2., 3., 2., 2.],
        [2., 4., 1., 2.],
        [4., 1., 2., 1.]])

>>> masks
tensor([[[False,  True,  True,  True],
         [ True, False, False,  True],
         [ True, False,  True, False]],

        [[False,  True,  True,  True],
         [ True,  True, False, False],
         [False,  True,  True,  True]]])

>>> outp
tensor([[0., 3., 2., 2.],
        [2., 4., 0., 2.],
        [4., 1., 2., 1.]])

I don’t think there is a proper way to avoid the for loop, since you are sequentially overwriting outp in:

outp[m] = inp[m]

I.e. outp and inp have a shape of [3, 4], and you are creating num_masks masks with the same shape.
In the first iteration the first mask would thus manipulate outp; in the second iteration the already modified outp tensor would be manipulated again, so you have a sequential dependency.

One potential way would be to “sum” all masks to a single one and perform this operation only once.

out = torch.zeros(inp.shape)
mask = masks.sum(dim=0).bool()
out[mask] = inp[mask]
out == outp

but it depends on your use case and whether that’s really what you are trying to achieve.
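A quick sketch checking that equivalence on the first example: every iteration copies values from the same inp, so the writes never depend on each other's results, and the loop collapses to a single torch.where over the union of the masks. masks.any(dim=0) is used here as an equivalent spelling of masks.sum(dim=0).bool().

```python
import torch

torch.manual_seed(16)
num_masks = 2
inp = torch.randint(1, 5, (3, 4), dtype=torch.float)
masks = torch.randn((num_masks,) + inp.shape) > 0

# Original loop: copy inp into outp wherever a mask is True
outp = torch.zeros(inp.shape)
for m in torch.unbind(masks, dim=0):
    outp[m] = inp[m]

# Loop-free version: keep inp where at least one mask is set, 0 elsewhere
union = masks.any(dim=0)
out = torch.where(union, inp, torch.zeros_like(inp))

print(torch.equal(out, outp))  # True
```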

Hi @ptrblck,

Summing the masks would not be an appropriate approach to this problem. Each mask represents a connected component (CC, of arbitrary shape) in the image. I use them for a custom max pooling operation: for each mask, all pixels of that CC are set to the maximum value within that CC.

The operation does not have to be sequential, but I could not think of a better way than the for loop, so any non-sequential approaches are especially welcome. Here is an example; I can post the whole code if needed, but I’m basically stuck at the for loop mentioned in the first post.

>>> inp
tensor([[[[11., 11., 10., 12.],
          [10., 11., 11., 10.],
          [11., 12., 12., 12.],
          [10., 12., 11., 10.]]]], requires_grad=True)

>>> CCs
tensor([[0., 0., 0., 1.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

>>> masks
tensor([[[[[ True,  True,  True, False],
           [False, False, False, False],
           [False, False, False, False],
           [False, False, False, False]]]],

        [[[[False, False, False,  True],
           [False, False, False, False],
           [False, False, False, False],
           [False, False, False, False]]]],

        [[[[False, False, False, False],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True],
           [ True,  True,  True,  True]]]]])

>>> output
tensor([[[[11., 11., 11., 12.],
          [12., 12., 12., 12.],
          [12., 12., 12., 12.],
          [12., 12., 12., 12.]]]])

Unfortunately, I don’t understand how your output tensor is created given the inp and masks (e.g. where is the 11. coming from in output[0, 0, 0, 2]).

In any case, using your manual loop approach and the “summed” mask approach would still yield the same result, as you are just applying the masks sequentially and could thus also just create a single mask:

# setup
inp = torch.tensor([[[[11., 11., 10., 12.],
                      [10., 11., 11., 10.],
                      [11., 12., 12., 12.],
                      [10., 12., 11., 10.]]]], requires_grad=True)

masks = torch.tensor([[[[[ True,  True,  True, False],
                         [False, False, False, False],
                         [False, False, False, False],
                         [False, False, False, False]]]],

                      [[[[False, False, False,  True],
                         [False, False, False, False],
                         [False, False, False, False],
                         [False, False, False, False]]]],

                      [[[[False, False, False, False],
                         [ True,  True,  True,  True],
                         [ True,  True,  True,  True],
                         [ True,  True,  True,  True]]]]])

# reference
reference_output = torch.tensor([[[[11., 11., 11., 12.],
                                   [12., 12., 12., 12.],
                                   [12., 12., 12., 12.],
                                   [12., 12., 12., 12.]]]])

# loop
loop_output = torch.zeros(inp.shape)
for m in torch.unbind(masks, dim=0):
    loop_output[m] = inp[m]

print((loop_output == reference_output).all())
> tensor(False)

# summed
mask = masks.sum(dim=0).bool()
summed_output = torch.zeros(inp.shape)
summed_output[mask] = inp[mask]
print((summed_output==reference_output).all())
> tensor(False)
print((summed_output==loop_output).all())
> tensor(True)

Unfortunately, I don’t understand how your output tensor is created given the inp and masks (e.g. where is the 11. coming from in output[0, 0, 0, 2]).

Sorry for the poor explanation. Each value in CCs indicates which connected component that pixel belongs to. So in the above example, 3 pixels belong to the 0th CC, 1 pixel belongs to the 1st CC and the rest belong to the 2nd CC.

For the custom pooling, all pixels of a CC are set to the maximum input value within that CC. These are 11. for the 0th CC, 12. for the 1st CC and again 12. for the 2nd CC.
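To have something to test faster versions against, the definition can be pinned down with a deliberately simple reference. This is a sketch for a single [H, W] image; maxify_reference is a hypothetical helper name, and the loop over component IDs is intentional because it only serves as ground truth:

```python
import torch

def maxify_reference(inp, ccs):
    """Hypothetical reference: set every pixel to the max of inp over its CC.

    inp: [H, W] float tensor; ccs: [H, W] tensor of component IDs.
    Loops over components, so it is only meant as ground truth for tests.
    """
    out = torch.empty_like(inp)
    for cc_id in torch.unique(ccs):
        mask = ccs == cc_id
        out[mask] = inp[mask].max()  # broadcast the component's max to its pixels
    return out

inp = torch.tensor([[11., 11., 10., 12.],
                    [10., 11., 11., 10.],
                    [11., 12., 12., 12.],
                    [10., 12., 11., 10.]])
ccs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])

print(maxify_reference(inp, ccs))
```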

The problem with the summed mask approach is that every pixel belongs to some connected component. Therefore, masks.sum(dim=0).bool() results in a mask filled only with True values, and applying it to the input just returns the input again. Here is the full operation; the for loop in forward is the bottleneck:

import torch

class AdvMaxify(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, inp, CCs):
        maxified = torch.ones(inp.shape, device=inp.device)
        cs = torch.unique(CCs)
        mask_shape = cs.shape + maxified.shape
        masks = CCs.expand(mask_shape) == cs.view(cs.shape + (1, 1, 1, 1)).expand(mask_shape)
        for m in torch.unbind(masks, dim=0): # bottleneck
            m_inp = inp * m
            dim_maxes = torch.amax(m_inp.view(inp.shape[0], inp.shape[1], -1).squeeze(), dim=-1)
            num_dims = m.dim() - dim_maxes.dim()
            new_shape = dim_maxes.shape + (1,) * num_dims if m.shape[0] > 1 else (1,) + dim_maxes.shape + (1,) * (num_dims - 1)
            maxified[m] = dim_maxes.view(new_shape).expand_as(m)[m]
        return maxified
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


if __name__ == "__main__":

    inp = torch.tensor([[[[11., 11., 10., 12.],
                          [10., 11., 11., 10.],
                          [11., 12., 12., 12.],
                          [10., 12., 11., 10.]]]], requires_grad=True)

    CCs = torch.tensor([[0., 0., 0., 1.],
                        [2., 2., 2., 2.],
                        [2., 2., 2., 2.],
                        [2., 2., 2., 2.]])

    reference_output = torch.tensor([[[[11., 11., 11., 12.],
                                       [12., 12., 12., 12.],
                                       [12., 12., 12., 12.],
                                       [12., 12., 12., 12.]]]])

    adv_maxify = AdvMaxify.apply
    outp = adv_maxify(inp, CCs)

print((outp == reference_output).all())
> tensor(True)
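To make the point about the summed mask concrete, here is a small sketch that builds one boolean mask per CC from the CCs tensor (the same broadcast comparison that forward uses, written here for a 2-D CCs map) and checks that the masks partition the image:

```python
import torch

ccs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])
ids = torch.unique(ccs)

# One boolean mask per component: shape [num_CCs, H, W]
masks = ccs.unsqueeze(0) == ids.view(-1, 1, 1)

# Every pixel is covered by exactly one mask ...
print(masks.sum(dim=0))
# ... so the "summed" mask is all True and selects the whole input
print(masks.sum(dim=0).bool().all())  # tensor(True)
```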

If the components are disjoint and you have the component IDs in a single integer tensor, you could first use scatter_max from the PyTorch Scatter library (torch_scatter) and then indexing with the mask to get the values back into the index.
There are faster ways to do this, but at least the one I implemented isn’t open source.

Best regards

Thomas
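A sketch of those two steps in plain PyTorch, assuming a recent release (2.x) where Tensor.scatter_reduce supports reduce="amax". It is used here as a stand-in for torch_scatter.scatter_max and returns only the maxima, not the argmax. torch.unique(..., return_inverse=True) relabels the component IDs to 0..K-1 so they can index the per-component result:

```python
import torch

inp = torch.tensor([[11., 11., 10., 12.],
                    [10., 11., 11., 10.],
                    [11., 12., 12., 12.],
                    [10., 12., 11., 10.]])
ccs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])

# Relabel component IDs to contiguous 0..K-1 (labels has the same shape as ccs)
ids, labels = torch.unique(ccs, return_inverse=True)

# Step 1: one maximum per component; include_self=False ignores the zero fill
cc_max = torch.zeros(len(ids), dtype=inp.dtype).scatter_reduce(
    0, labels.view(-1), inp.view(-1), reduce="amax", include_self=False)

# Step 2: index with the labels to write each maximum back to its pixels
maxified = cc_max[labels]

print(cc_max)  # tensor([11., 12., 12.])
print(maxified)
```

Because include_self=False, the zero-initialized buffer does not take part in the max, so components whose values are all negative are also handled correctly.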

Hi @tom, yes, the connected components are disjoint.

and then indexing with the mask to get the values back into the index.

Could you elaborate on this part a bit more, please?

Update: This is what I could come up with based on your approach, but it still has the same bottleneck for loop:

import torch
from torch_scatter import scatter_max

inp = torch.tensor([[[[11., 11., 10., 12.],
                      [10., 11., 11., 10.],
                      [11., 12., 12., 12.],
                      [10., 12., 11., 10.]]]], requires_grad=True)

CCs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])

reference_output = torch.tensor([[[[11., 11., 11., 12.],
                                   [12., 12., 12., 12.],
                                   [12., 12., 12., 12.],
                                   [12., 12., 12., 12.]]]])


indexes = torch.unique(CCs)
out = inp.new_zeros(indexes.shape)
out, argmax = scatter_max(inp.view(-1), CCs.view(-1), out=out)
maxified = inp.new_ones(inp.shape)

for i in range(len(indexes)):
    mask = indexes[i] == CCs
    maxified[mask.view((1, 1) + mask.shape)] = out[i]

print((maxified == reference_output).all())
> tensor(True)
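One way to read the "indexing back" step: since out holds one maximum per component ID, indexing it with the integer CCs tensor itself picks the right value for every pixel at once, which removes the remaining loop. In this sketch out is hard-coded to the maxima for this example (11., 12., 12.) so it runs without torch_scatter. It assumes the IDs are non-negative integers and that out has at least max(ID) + 1 entries; new_zeros(indexes.shape) only guarantees that when the IDs are exactly 0..K-1.

```python
import torch

# Per-component maxima for the example above, as scatter_max returns them
# (hard-coded here so the snippet runs without torch_scatter)
out = torch.tensor([11., 12., 12.])
CCs = torch.tensor([[0, 0, 0, 1],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2],
                    [2, 2, 2, 2]])

# Advanced indexing: pixel (i, j) takes out[CCs[i, j]], giving shape [4, 4];
# then add the batch and channel dims to match inp's [1, 1, 4, 4]
maxified = out[CCs].view((1, 1) + CCs.shape)
print(maxified)
```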

I found a way to do it without a for loop by combining the approaches of @tom and @ptrblck, with some support from friends:

import torch
from torch_scatter import scatter_max

class AdvMaxify(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, inp, CCs):
        indexes = torch.unique(CCs)                                                     # These are also unique connected component ids. 
        cc_maxes = inp.new_zeros(inp.shape[:2] + indexes.shape)                         # Tensor in [bs, ch, num_CCs] shape to keep each CC's max value. 
        cc_maxes, _ = scatter_max(src=inp.view(inp.shape[:2] + (-1,)),
                                  index=CCs.unsqueeze(1).expand_as(inp).view(inp.shape[:2] + (-1,)),
                                  out=cc_maxes)
        mask_shape = indexes.shape + inp.shape
        masks = CCs.unsqueeze(1).expand(mask_shape) == indexes.view(indexes.shape + (1, 1, 1, 1)).expand(mask_shape)
        return torch.sum(masks * cc_maxes.permute(2, 0, 1).contiguous().view(mask_shape[:3] + (1, 1)).expand(mask_shape), dim=0)


    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

if __name__ == "__main__":
    adv_maxify = AdvMaxify.apply
    inp = torch.tensor([[[[ 8.,  5., 10., 12.],
                          [11.,  7., 10., 12.],
                          [10.,  8., 12., 10.],
                          [ 8., 11., 10.,  6.]],
    
                         [[ 7., 10.,  6.,  5.],
                          [ 7.,  7.,  5.,  7.],
                          [12., 10.,  5., 11.],
                          [ 7.,  6.,  7., 12.]],
    
                         [[ 9.,  7.,  5.,  9.],
                          [ 7.,  9., 12.,  9.],
                          [10., 11., 11.,  8.],
                          [10., 10.,  6.,  8.]]],
    
                        [[[10., 11., 11., 10.],
                          [ 8., 10., 12.,  7.],
                          [ 8.,  5.,  7.,  8.],
                          [ 6.,  9.,  6.,  8.]],
    
                         [[ 9., 10.,  7., 11.],
                          [ 8.,  9.,  7., 11.],
                          [ 8., 10.,  6., 11.],
                          [10.,  7.,  7.,  6.]],
    
                         [[12., 10., 10., 11.],
                          [12., 11., 10.,  8.],
                          [12.,  8., 10.,  9.],
                          [12., 10., 10., 10.]]]], requires_grad=True)
    
    cidx = torch.tensor([[[0, 1, 1, 1],
                          [2, 3, 3, 3],
                          [2, 3, 3, 3],
                          [2, 3, 3, 3]],
    
                         [[0, 1, 1, 2],
                          [0, 1, 1, 2],
                          [3, 4, 4, 5],
                          [3, 4, 4, 5]]])
    
    reference_output = torch.tensor([[[[ 8., 12., 12., 12.],
                                       [11., 12., 12., 12.],
                                       [11., 12., 12., 12.],
                                       [11., 12., 12., 12.]],
     
                                      [[ 7., 10., 10., 10.],
                                       [12., 12., 12., 12.],
                                       [12., 12., 12., 12.],
                                       [12., 12., 12., 12.]],
     
                                      [[ 9.,  9.,  9.,  9.],
                                       [10., 12., 12., 12.],
                                       [10., 12., 12., 12.],
                                       [10., 12., 12., 12.]]],
     
                                     [[[10., 12., 12., 10.],
                                       [10., 12., 12., 10.],
                                       [ 8.,  9.,  9.,  8.],
                                       [ 8.,  9.,  9.,  8.]],
     
                                      [[ 9., 10., 10., 11.],
                                       [ 9., 10., 10., 11.],
                                       [10., 10., 10., 11.],
                                       [10., 10., 10., 11.]],
     
                                      [[12., 11., 11., 11.],
                                       [12., 11., 11., 11.],
                                       [12., 10., 10., 10.],
                                       [12., 10., 10., 10.]]]]) 
    
    maxed = adv_maxify(inp, cidx)
    
print(torch.all(reference_output == maxed))
> tensor(True)
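The forward above still materializes a [num_CCs, B, C, H, W] mask tensor, so memory grows with the number of components (hundreds in the original question). A sketch of an alternative that follows the "index back" idea with gather instead, assuming the same shapes as above (inp is [B, C, H, W], CCs is [B, H, W] with non-negative integer IDs) and using plain-PyTorch scatter_reduce (PyTorch 2.x assumed) in place of torch_scatter.scatter_max; maxify_gather is a hypothetical name:

```python
import torch

def maxify_gather(inp, ccs):
    """Hypothetical per-component max pooling without a [num_CCs, ...] mask tensor.

    inp: [B, C, H, W] float tensor; ccs: [B, H, W] integer component IDs >= 0.
    """
    B, C, H, W = inp.shape
    # Same component map for every channel, flattened: [B, C, H*W]
    index = ccs.long().unsqueeze(1).expand(B, C, H, W).reshape(B, C, H * W)
    src = inp.reshape(B, C, H * W)
    num_slots = int(ccs.max()) + 1  # one slot per possible ID
    # Per-(batch, channel, component) maximum; include_self=False ignores the zero fill
    cc_max = src.new_zeros(B, C, num_slots).scatter_reduce(
        2, index, src, reduce="amax", include_self=False)
    # Write each component's maximum back to its pixels
    return cc_max.gather(2, index).view(B, C, H, W)


if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 4)
    ids = torch.randint(0, 5, (2, 4, 4))
    print(maxify_gather(x, ids).shape)  # torch.Size([2, 3, 4, 4])
```

Inside AdvMaxify.forward, return maxify_gather(inp, CCs) could replace the mask construction; the intermediate buffers then scale roughly with B·C·H·W plus one slot per component ID, instead of with num_CCs·B·C·H·W.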

Good job. Congratulations!
