Any PyTorch equivalent to tf.segment_max?

Hi,

In TensorFlow, we have tf.segment_max (https://www.tensorflow.org/api_docs/python/tf/segment_max), which computes the maximum of each segment of a tensor along its first dimension.

What is the equivalent way to do this in PyTorch?
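For reference, this is a minimal sketch of the behaviour I'm after. It assumes PyTorch 1.12 or later, where `Tensor.scatter_reduce_` supports `reduce="amax"`, and that `segment_ids` is a non-empty 1-D `torch.long` tensor. The helper name `segment_max` is hypothetical. Like tf.segment_max, empty segments come out as 0.

```python
import torch

def segment_max(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking tf.segment_max (assumes PyTorch >= 1.12).

    segment_ids: 1-D long tensor with one entry per row of `data`.
    """
    num_segments = int(segment_ids.max().item()) + 1
    out_shape = (num_segments,) + tuple(data.shape[1:])
    # Start from zeros so segments that receive no rows stay 0, as in TF.
    out = torch.zeros(out_shape, dtype=data.dtype, device=data.device)
    # Broadcast the ids across the trailing dims so index matches data's shape.
    index = segment_ids.view(-1, *([1] * (data.dim() - 1))).expand_as(data)
    # include_self=False: the initial zeros do not take part in the max
    # for segments that do receive rows.
    out.scatter_reduce_(0, index, data, reduce="amax", include_self=False)
    return out

data = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]])
print(segment_max(data, torch.tensor([0, 0, 1])))
```

Unlike TF, this sketch does not require `segment_ids` to be sorted, because the scatter works on any index order.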

Thanks!