Comparing two tensors without using for loop

Hi all, I need to compare each row of one tensor with the element at the corresponding index of another tensor, without using a for loop. Here is an example:

inp = torch.tensor([[[[[ 6.,  2.,  5.,  6.],
                       [ 3.,  4.,  7.,  8.]],

                      [[ 9., 10., 14., 14.],
                       [11., 16., 16., 16.]]]]])

maxes = torch.tensor([[[[ 6.,  8.],
                        [14., 16.]]]])

The comparison should return:

tensor([[ True, False, False,  True],
        [False, False, False,  True],
        [False, False,  True,  True],
        [False,  True,  True,  True]])

Any ideas on how to achieve that?

I think you want expand (see the torch.Tensor page of the PyTorch 1.7.0 documentation). You can expand maxes to the full size of the input and then do an elementwise comparison.
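A minimal sketch of the expand approach, using the tensors from the question. It assumes maxes always has exactly one dimension fewer than inp, missing only the trailing one; unsqueeze(-1) and expand_as are standard torch.Tensor methods, and the variable name expanded is just for illustration.

```python
import torch

inp = torch.tensor([[[[[6., 2., 5., 6.],
                       [3., 4., 7., 8.]],
                      [[9., 10., 14., 14.],
                       [11., 16., 16., 16.]]]]])  # shape (1, 1, 2, 2, 4)
maxes = torch.tensor([[[[6., 8.],
                        [14., 16.]]]])            # shape (1, 1, 2, 2)

# Add a trailing size-1 dim so maxes becomes (1, 1, 2, 2, 1), then expand it
# along that dim to inp's shape. expand returns a view; no data is copied.
expanded = maxes.unsqueeze(-1).expand_as(inp)

# Elementwise comparison; result has the same shape as inp.
result = inp == expanded

# Flatten the leading dims to print one row per compared row.
print(result.reshape(-1, inp.shape[-1]))
```

Because expand only creates a view, this costs no extra memory even for large inputs.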


Try:

result = (inp == maxes[..., None]).squeeze()
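A sketch of why the one-liner works, with one caveat about the output shape. Indexing with [..., None] adds a trailing size-1 dim, and broadcasting does the expansion implicitly. Note that .squeeze() yields shape (2, 2, 4), not the 4 x 4 layout shown in the question, and it drops every size-1 dim, so a reshape over the leading dims may be the safer way to get one row per comparison.

```python
import torch

inp = torch.tensor([[[[[6., 2., 5., 6.],
                       [3., 4., 7., 8.]],
                      [[9., 10., 14., 14.],
                       [11., 16., 16., 16.]]]]])  # shape (1, 1, 2, 2, 4)
maxes = torch.tensor([[[[6., 8.],
                        [14., 16.]]]])            # shape (1, 1, 2, 2)

# maxes[..., None] has shape (1, 1, 2, 2, 1); broadcasting stretches the
# last dim to 4 to match inp without an explicit expand.
mask = inp == maxes[..., None]            # shape (1, 1, 2, 2, 4)
print(mask.squeeze().shape)               # torch.Size([2, 2, 4])

# Flatten all leading dims into one to match the 4 x 4 output in the question.
result = mask.reshape(-1, inp.shape[-1])  # shape (4, 4)
print(result)
```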
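In the example data, each entry of maxes happens to equal the maximum of the corresponding row of inp. The question does not say this always holds, so treat it as an assumption. If it does hold, maxes can be computed with keepdim=True, which keeps the reduced dim as size 1 so the result broadcasts directly. This is a sketch under that assumption; row_max is a hypothetical name.

```python
import torch

inp = torch.tensor([[[[[6., 2., 5., 6.],
                       [3., 4., 7., 8.]],
                      [[9., 10., 14., 14.],
                       [11., 16., 16., 16.]]]]])  # shape (1, 1, 2, 2, 4)
maxes = torch.tensor([[[[6., 8.],
                        [14., 16.]]]])            # shape (1, 1, 2, 2)

# Assumption: maxes is the per-row maximum of inp. With keepdim=True the
# reduced last dim stays as size 1, giving shape (1, 1, 2, 2, 1).
row_max = inp.max(dim=-1, keepdim=True).values
print(torch.equal(row_max.squeeze(-1), maxes))  # True for this example

# Compare directly; the size-1 dim broadcasts across each row.
result = (inp == row_max).reshape(-1, inp.shape[-1])  # shape (4, 4)
print(result)
```

With this approach, True marks every position holding a row's maximum, including ties such as the two 14s.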