Mask a Tensor based on the max value of a selected dimension


I have a tensor of shape [n, m, m] (n images, each m*m). I want to mask each image according to the maximum value of each of its rows. The result should have the same shape as the original, but each [m, m] slice should now be a mask matrix. What is the fastest way to do this on the GPU?

You can use the comparison operators, which return a mask with the same shape as their (broadcast) inputs: a bool tensor in current PyTorch, or a 0/1 ByteTensor in older releases.
For example:

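import torch
t = torch.rand(5, 10, 10)
# row-wise max; keepdim=True keeps the shape [5, 10, 1] so it broadcasts against t
tmax = t.max(-1, keepdim=True)[0]

mask = tmax > t   # row max strictly greater than the element: True everywhere except at the row's max
mask = t >= tmax  # element greater than or equal to the row max: True only at the row's max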
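A slightly fuller sketch, assuming a recent PyTorch (1.7 or later, for `torch.amax`). It puts the tensor on the GPU when one is available, builds the max mask, shows an argmax-based alternative that marks exactly one position per row even when several entries tie for the maximum, and then applies the mask. The sizes `n`, `m` and the names `max_mask`, `one_hot`, `kept` are illustrative, not part of the original answer.

```python
import torch
import torch.nn.functional as F

# Use the GPU when available; every op below runs on whichever device t lives on.
device = "cuda" if torch.cuda.is_available() else "cpu"

n, m = 5, 10
t = torch.rand(n, m, m, device=device)

# Row-wise maximum, kept as shape [n, m, 1] so it broadcasts against t.
tmax = t.amax(dim=-1, keepdim=True)

# True wherever an element equals its row's maximum (ties give several Trues).
max_mask = t >= tmax

# Alternative: exactly one True per row, at the index argmax picks.
idx = t.argmax(dim=-1)                              # shape [n, m]
one_hot = F.one_hot(idx, num_classes=m).bool()      # shape [n, m, m]

# Apply the mask: keep the row maxima and zero everything else.
kept = t * max_mask                                 # bool is promoted to float
kept_alt = t.masked_fill(~max_mask, 0.0)            # same result via masked_fill
```

Both approaches are elementwise, broadcast tensor operations, so they process all n images in one call with no Python loop over images or rows.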