Getting min before a max

Given a tensor of values with dimensions [batch, values], I need to get, for each row, the minimum value that occurs before the maximum value. My naive approach was:

value_max, value_max_idx = data.max(dim=1)
value_min_before_max, _ = data[:value_max_idx + 1].min(dim=1)

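This obviously didn’t work: value_max_idx is a tensor of shape [batch] (one index per row), so it can’t be used as a slice bound, and data[:k] would slice the batch dimension rather than the values anyway. But I feel like there should be a simple answer here… does anyone have a good idea?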
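To pin down the intended result, here is a slow per-row reference. The helper name `min_before_max_loop` and the sample data are hypothetical. Converting each row's argmax to a Python int is exactly the step a vectorized version has to avoid, so this is only meant for checking results on small inputs.

```python
import torch

def min_before_max_loop(data: torch.Tensor) -> torch.Tensor:
    """Slow per-row reference: min of each row up to and including its argmax."""
    out = []
    for row in data:  # iterate over the batch dimension
        idx = int(row.argmax())  # a Python int can be used as a slice bound
        out.append(row[: idx + 1].min())
    return torch.stack(out)

data = torch.tensor([[3.0, 1.0, 5.0, 0.0],
                     [2.0, 4.0, 1.0, 3.0]])
print(min_before_max_loop(data))  # tensor([1., 2.])
```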

You could mask the data after the max with +inf (or, for integer dtypes, the maximum value representable by your dtype) before taking the min:

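import math
import torch

# Fill value for entries after the max (+inf only works for floating-point dtypes)
inf_tensor = torch.tensor(math.inf, device=data.device, dtype=data.dtype)
# Column indices on the same device as data, broadcast against each row's argmax
col_idx = torch.arange(data.size(1), device=data.device)
value_max, value_max_idx = data.max(dim=1)
masked_data = torch.where(col_idx[None, :] <= value_max_idx[:, None], data, inf_tensor)
value_min_before_max, _ = masked_data.min(dim=1)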
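A minimal sketch that wraps the masking into a function and picks the fill value from the dtype, assuming data is a 2-D float or integer tensor (the name `min_before_max` is hypothetical). It uses the out-of-place `Tensor.masked_fill` with the inverted condition instead of `torch.where`; the result is the same, and `masked_fill` accepts a plain Python number as the fill value.

```python
import math
import torch

def min_before_max(data: torch.Tensor) -> torch.Tensor:
    """Per-row minimum over the entries up to and including the row's maximum."""
    _, value_max_idx = data.max(dim=1)
    col_idx = torch.arange(data.size(1), device=data.device)
    # True for entries strictly after the max in each row
    after_max = col_idx[None, :] > value_max_idx[:, None]
    # Fill value that can never win the min: +inf for floats, the dtype's max for integers
    if data.dtype.is_floating_point:
        fill = math.inf
    else:
        fill = torch.iinfo(data.dtype).max
    return data.masked_fill(after_max, fill).min(dim=1).values

floats = torch.tensor([[3.0, 1.0, 5.0, 0.0],
                       [2.0, 4.0, 1.0, 3.0]])
ints = torch.tensor([[3, 1, 5, 0],
                     [2, 4, 1, 3]])
print(min_before_max(floats))  # tensor([1., 2.])
print(min_before_max(ints))    # tensor([1, 2])
```

Note that the max itself is included in the min, so a row whose max sits in the first column returns that max.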

Depending on your situation, overwriting the masked entries of data in place might be an alternative that avoids allocating masked_data as a separate tensor, but I generally advise against in-place operations unless you absolutely need them.
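For completeness, a sketch of the in-place variant using `Tensor.masked_fill_`, assuming a float tensor that you no longer need in its original form (the sample data is hypothetical). A boolean mask of the same shape is still created, and the overwritten values are gone afterwards; if data is a leaf tensor that requires grad, autograd will refuse the in-place call.

```python
import math
import torch

data = torch.tensor([[3.0, 1.0, 5.0, 0.0],
                     [2.0, 4.0, 1.0, 3.0]])

_, value_max_idx = data.max(dim=1)
col_idx = torch.arange(data.size(1), device=data.device)
# Overwrite the entries after each row's max in place: no masked_data is allocated,
# but the original values in those positions are lost.
data.masked_fill_(col_idx[None, :] > value_max_idx[:, None], math.inf)
value_min_before_max, _ = data.min(dim=1)
print(value_min_before_max)  # tensor([1., 2.])
```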

Best regards