How to perform segment max?

An important pooling-like operation is to take the elementwise maximum along an axis. For batched tensors that all share a shape N, D1, D2, ..., torch.max is all we need. But when the tensors have different first dimensions (N1, D1, D2, ...; N2, D1, D2, ...; and so on), we need a segment max. Is there a recommended way to do segment max with PyTorch? I know TensorFlow has something like tf.math.unsorted_segment_max.

If not, as a workaround, which is more recommended: a for loop plus concatenation, or padding? Thanks.

PyTorch does not have all the segmentation ops that TensorFlow has.
The best way to do this depends on how you have the tensors to start with. If they're already in one Tensor with padding, then keep it that way. If you have a small number of them, or widely different sizes, a for-loop will be better. If you have many with similar sizes, creating the padded Tensor might be interesting.
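To sketch the padding route under one assumption (a list of float tensors with shapes (N_i, D)): pad every tensor to the longest N_i with -inf so the filler never wins the max, then reduce in a single batched call. The helper name `padded_segment_max` is illustrative, not a PyTorch API:

```python
import torch
import torch.nn.functional as F

def padded_segment_max(tensors):
    # Pad each (N_i, D) tensor at the bottom with -inf rows up to the
    # longest N_i, stack into (num_segments, max_n, D), then max over dim 1.
    max_n = max(t.shape[0] for t in tensors)
    padded = torch.stack([
        F.pad(t, (0, 0, 0, max_n - t.shape[0]), value=float('-inf'))
        for t in tensors
    ])
    return padded.max(dim=1).values

segs = [torch.tensor([[1., 5.], [3., 2.]]),
        torch.tensor([[4., 0.]])]
print(padded_segment_max(segs))
# tensor([[3., 5.],
#         [4., 0.]])
```

The -inf fill keeps the padding neutral for max; for a segment *sum* the neutral fill would be 0 instead.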


Got it, thank you Alban!