I found a way to do it without a for
loop, combining the approaches of @tom, @ptrblck and with some support from friends,
import torch
from torch_scatter import scatter_max
class AdvMaxify(torch.autograd.Function):
@staticmethod
def forward(ctx, inp, CCs):
indexes = torch.unique(CCs) # These are also unique connected component ids.
cc_maxes = inp.new_zeros(inp.shape[:2] + indexes.shape) # Tensor in [bs, ch, num_CCs] shape to keep each CC's max value.
cc_maxes, _ = scatter_max(src=inp.view(inp.shape[:2] + (-1,)),
index=CCs.unsqueeze(1).expand_as(inp).view(inp.shape[:2] + (-1,)),
out=cc_maxes)
mask_shape = indexes.shape + inp.shape
masks = CCs.unsqueeze(1).expand(mask_shape) == indexes.view(indexes.shape + (1, 1, 1, 1)).expand(mask_shape)
return torch.sum(masks * cc_maxes.permute(2, 0, 1).contiguous().view(mask_shape[:3] + (1, 1)).expand(mask_shape), dim=0)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
if __name__ == "__main__":
adv_maxify = AdvMaxify.apply
inp = torch.tensor([[[[ 8., 5., 10., 12.],
[11., 7., 10., 12.],
[10., 8., 12., 10.],
[ 8., 11., 10., 6.]],
[[ 7., 10., 6., 5.],
[ 7., 7., 5., 7.],
[12., 10., 5., 11.],
[ 7., 6., 7., 12.]],
[[ 9., 7., 5., 9.],
[ 7., 9., 12., 9.],
[10., 11., 11., 8.],
[10., 10., 6., 8.]]],
[[[10., 11., 11., 10.],
[ 8., 10., 12., 7.],
[ 8., 5., 7., 8.],
[ 6., 9., 6., 8.]],
[[ 9., 10., 7., 11.],
[ 8., 9., 7., 11.],
[ 8., 10., 6., 11.],
[10., 7., 7., 6.]],
[[12., 10., 10., 11.],
[12., 11., 10., 8.],
[12., 8., 10., 9.],
[12., 10., 10., 10.]]]], requires_grad=True)
cidx = torch.tensor([[[0, 1, 1, 1],
[2, 3, 3, 3],
[2, 3, 3, 3],
[2, 3, 3, 3]],
[[0, 1, 1, 2],
[0, 1, 1, 2],
[3, 4, 4, 5],
[3, 4, 4, 5]]])
reference_output = torch.tensor([[[[ 8., 12., 12., 12.],
[11., 12., 12., 12.],
[11., 12., 12., 12.],
[11., 12., 12., 12.]],
[[ 7., 10., 10., 10.],
[12., 12., 12., 12.],
[12., 12., 12., 12.],
[12., 12., 12., 12.]],
[[ 9., 9., 9., 9.],
[10., 12., 12., 12.],
[10., 12., 12., 12.],
[10., 12., 12., 12.]]],
[[[10., 12., 12., 10.],
[10., 12., 12., 10.],
[ 8., 9., 9., 8.],
[ 8., 9., 9., 8.]],
[[ 9., 10., 10., 11.],
[ 9., 10., 10., 11.],
[10., 10., 10., 11.],
[10., 10., 10., 11.]],
[[12., 11., 11., 11.],
[12., 11., 11., 11.],
[12., 10., 10., 10.],
[12., 10., 10., 10.]]]])
maxed = adv_maxify(inp, cidx)
print(torch.all(reference_output == maxed))
> tensor(True)