Array Index operation

In the code below, “b” is a single-channel tensor (shape 1x1x3x3) with all values equal to 1, and “a” is a three-channel tensor. In “a”, I want to find the indices at which all channel values are greater than 100 and the 3rd-channel value is greater than both the 1st- and 2nd-channel values. The value of “b” at those indices should be replaced with 0.
For example, the first position in “a” satisfies this condition (its values in the three channels are 160, 185 and 210), so b[0,0,0,0] should be 0.

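```python
import torch

# b: single-channel tensor of ones that should receive the zeros
b = torch.ones(1, 1, 3, 3)
# a: three-channel input tensor, shape (1, 3, 3, 3)
a = torch.FloatTensor([[[[160, 56, 20],
                         [54, 6, 97],
                         [65, 119, 56]],

                        [[185, 13, 90],
                         [10, 6, 220],
                         [67, 88, 4]],

                        [[210, 92, 74],
                         [48, 217, 86],
                         [42, 12, 24]]]])
```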

This should work:

```python
import torch

a = torch.tensor([[[[160, 56, 20],
                    [54, 6, 97],
                    [65, 119, 56]],

                   [[185, 13, 90],
                    [10, 6, 220],
                    [67, 88, 4]],

                   [[210, 92, 74],
                    [48, 217, 86],
                    [42, 12, 24]]]]).float()

# True where channel 2 is larger than channels 0 and 1 and is itself above 100
idx = (a[:, 2] > a[:, 0]) & (a[:, 2] > a[:, 1]) & (a[:, 2] > 100.)
result = idx.float()
print(result)
# tensor([[[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 0.]]])
```
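The mask above compares only the 3rd channel against 100, so position (1, 1), whose channel values are 6, 6 and 217, is flagged as well, and `result` is the mask itself rather than the updated `b`. Below is a minimal sketch that also requires every channel to exceed 100, as the question asks, and then writes zeros into `b` through boolean indexing. The names `all_above`, `third_largest` and `mask` are illustrative only. `unsqueeze(1)` inserts the singleton channel dimension so the mask's shape `(1, 3, 3)` becomes `(1, 1, 3, 3)`, matching `b`.

```python
import torch

a = torch.tensor([[[[160, 56, 20],
                    [54, 6, 97],
                    [65, 119, 56]],
                   [[185, 13, 90],
                    [10, 6, 220],
                    [67, 88, 4]],
                   [[210, 92, 74],
                    [48, 217, 86],
                    [42, 12, 24]]]]).float()
b = torch.ones(1, 1, 3, 3)

# Every channel above 100 at a given (row, col): reduces dim 1, shape (1, 3, 3)
all_above = (a > 100).all(dim=1)
# 3rd channel strictly larger than the 1st and 2nd channels: shape (1, 3, 3)
third_largest = (a[:, 2] > a[:, 0]) & (a[:, 2] > a[:, 1])
mask = all_above & third_largest

# Add the channel dim so the mask has b's shape (1, 1, 3, 3), then zero those entries
b[mask.unsqueeze(1)] = 0
print(b)  # only b[0, 0, 0, 0] becomes 0; every other entry stays 1
```

If `b` always starts as all ones, `(~mask).unsqueeze(1).float()` builds the same tensor without an in-place write.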