Is there a PyTorch equivalent of tf.segment_max()?

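As the title says: TensorFlow has tf.segment_max() (tf.math.segment_max), which takes a tensor and a vector of sorted segment IDs and returns the maximum over each segment along the first dimension. Does PyTorch have a built-in function that does the same thing?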
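
To make the behaviour I am after concrete, here is a minimal sketch of it in PyTorch. It assumes a PyTorch version that provides `Tensor.scatter_reduce` with `reduce="amax"`, floating-point input, and a known number of segments. The helper name `segment_max` and its `num_segments` parameter are my own, not an existing PyTorch API, and I am not claiming this is the recommended way to do it.

```python
import torch


def segment_max(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: max over the rows of `data` that share a segment id.

    Assumes `data` is floating point and `segment_ids` is an int64 vector
    with one id per row of `data`.
    """
    # Start every output slot at -inf so that any real value wins the max.
    out = torch.full(
        (num_segments,) + tuple(data.shape[1:]),
        float("-inf"),
        dtype=data.dtype,
        device=data.device,
    )
    # scatter_reduce needs an index with the same shape as the source,
    # so broadcast each row's segment id across the trailing dimensions.
    index = segment_ids.reshape(-1, *([1] * (data.dim() - 1))).expand_as(data)
    return out.scatter_reduce(0, index, data, reduce="amax", include_self=True)


data = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                     [4.0, 3.0, 2.0, 1.0],
                     [5.0, 6.0, 7.0, 8.0]])
segment_ids = torch.tensor([0, 0, 1])

print(segment_max(data, segment_ids, num_segments=2))
# rows 0 and 1 are reduced together, row 2 forms its own segment
```

In this sketch a segment that receives no rows stays at -inf, which may not match what TensorFlow returns for empty segments.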