An important pooling-like operation is taking the elementwise maximum along an axis. For batched tensors that all have the same size `N, D1, D2, ...`, `torch.max` is all we need. But when the first dimension differs between tensors (`N1, D1, D2, ...`, `N2, D1, D2, ...`, …), we need a segment max. Is there a recommended way to do segment max in PyTorch? I know TensorFlow has something like `tf.math.unsorted_segment_max`.
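For reference, here is a rough sketch of a segment max in plain PyTorch. It assumes a PyTorch release recent enough to have `Tensor.scatter_reduce` with `reduce="amax"` (older versions lack it or used a different signature, so check yours), and a floating-point dtype so `-inf` can serve as the fill value. All segments are concatenated along dim 0 and tagged with an int64 segment id per row, similar in spirit to `tf.math.unsorted_segment_max`. The helper name `segment_max` is hypothetical.

```python
import torch

def segment_max(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: elementwise max over rows that share a segment id.

    data:        (N_total, D1, D2, ...) floating-point tensor, segments concatenated on dim 0
    segment_ids: (N_total,) int64 tensor giving the segment of each row
    returns:     (num_segments, D1, D2, ...) tensor; empty segments stay at -inf
    """
    # scatter_reduce needs an index with the same number of dims as data,
    # so broadcast each row's segment id across the trailing dimensions.
    index = segment_ids.view(-1, *([1] * (data.dim() - 1))).expand_as(data)
    out = torch.full((num_segments, *data.shape[1:]), float("-inf"),
                     dtype=data.dtype, device=data.device)
    # include_self=False: the -inf fill does not take part in the reduction.
    return out.scatter_reduce(0, index, data, reduce="amax", include_self=False)

# Usage: build segment ids from per-tensor lengths N1, N2, ...
a = torch.tensor([[1., 5.], [3., 2.]])
b = torch.tensor([[0., 1.], [4., 4.], [2., 9.]])
lengths = torch.tensor([a.shape[0], b.shape[0]])
seg = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
result = segment_max(torch.cat([a, b]), seg, num_segments=2)
```

Here `result` should equal `[[3., 5.], [4., 9.]]`, i.e. the per-tensor column maxima.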
If not, which workaround is recommended: a for loop plus concatenation, or padding? Thanks.
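To make the two workarounds concrete, here is a minimal sketch of each, assuming a list of floating-point tensors that share `D1, D2, ...` but differ in the first dimension. The function names are hypothetical. The padded version uses `torch.nn.utils.rnn.pad_sequence` with `-inf` as the padding value so padded rows can never win the max.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def max_loop(tensors):
    # Hypothetical helper: reduce each tensor separately (one reduction per
    # tensor), keep dim 0 so the results can be concatenated into (B, D1, ...).
    return torch.cat([t.amax(dim=0, keepdim=True) for t in tensors], dim=0)

def max_padded(tensors):
    # Hypothetical helper: pad to (B, N_max, D1, ...) with -inf, then take a
    # single batched max over the padded length dimension.
    padded = pad_sequence(tensors, batch_first=True, padding_value=float("-inf"))
    return padded.amax(dim=1)

tensors = [torch.randn(n, 3, 4) for n in (2, 5, 1)]
loop_out = max_loop(tensors)
padded_out = max_padded(tensors)
```

Both return a `(B, D1, D2, ...)` tensor with the same values. As a rough trade-off, the loop avoids extra memory but issues one reduction per tensor, while padding does a single batched reduction at the cost of allocating `B * N_max` rows, which can be wasteful when lengths vary a lot.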