How to filter a tensor's rows by a float value

Hi, guys.

I have a tensor x with shape (a, b, c), where a is the batch, b holds the coordinate values and c is a float. I'd like to filter these values by a threshold y. I'm doing this:

conf_mask = (prediction[:,:,4] > threshold_conf).float().unsqueeze(2)
prediction *= conf_mask
prediction = prediction[prediction[:,:,4] > 0]

But the tensor shapes differ: before, it is (a, b, c); after, it is (b, c). What am I doing wrong? In fact, I'd like to keep only the values from x that are above the confidence threshold.
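For reference, here is a minimal sketch of what boolean-mask indexing does to the shape. It assumes a toy tensor with 2 images, 4 boxes per image and 6 values per box, with column 4 as the confidence (as in the code above); the sizes and values are made up for illustration.

```python
import torch

# Toy stand-in for `prediction`: 2 images, 4 boxes per image, 6 values per box.
# Column 4 is assumed to hold the confidence score, as in the question.
prediction = torch.zeros(2, 4, 6)
prediction[0, :, 4] = torch.tensor([0.9, 0.1, 0.7, 0.2])
prediction[1, :, 4] = torch.tensor([0.3, 0.8, 0.1, 0.05])
threshold_conf = 0.5

# The mask covers the first two dims (batch, boxes): shape (2, 4).
conf_mask = prediction[:, :, 4] > threshold_conf

# Indexing with a 2-D boolean mask merges those two dims into one:
# every True entry becomes a row, regardless of which image it came from.
kept = prediction[conf_mask]  # shape (3, 6) here

# The image each kept row belongs to can be recovered from the mask itself.
batch_idx = conf_mask.nonzero()[:, 0]  # tensor([0, 0, 1])
```

So a 2-D result is the expected behavior of mask indexing, not a bug in the thresholding.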

It looks like the post-processing step of a detection task.

conf_mask = (prediction[..., 4] > threshold_conf)
prediction = prediction[..., conf_mask]

Yes, @Eta_C, exactly. But with your suggestion I'm still getting an error on the second line (prediction = prediction[..., conf_mask]):

IndexError: The shape of the mask [1, 10647] at index 0 does not match the shape of the indexed tensor [1, 10647, 6] at index 1

Sorry, it should be

prediction = prediction[conf_mask, :]
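A small sketch of why the position of the mask matters, using a toy tensor shaped like the one in the traceback (batch of 1, 6 values per box, but only 5 boxes; the values are invented). A boolean index consumes as many dims as it has, starting at the position where it appears.

```python
import torch

# Toy tensor shaped like the traceback's [1, 10647, 6], but with 5 boxes.
prediction = torch.zeros(1, 5, 6)
prediction[0, :, 4] = torch.tensor([0.9, 0.2, 0.6, 0.1, 0.4])
threshold_conf = 0.5
conf_mask = prediction[..., 4] > threshold_conf  # shape (1, 5)

# With `...` in front, the (1, 5) mask is matched against the LAST two dims,
# (5, 6), so the sizes disagree and PyTorch raises IndexError.
try:
    prediction[..., conf_mask]
    failed = False
except IndexError:
    failed = True

# Placing the mask first matches it against dims 0 and 1, i.e. (1, 5).
# The trailing `:` is optional: prediction[conf_mask] selects the same rows.
kept = prediction[conf_mask, :]  # shape (2, 6)
```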

Hi, Eta! Thanks for your help, but it still doesn't work: the output is the same.
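Both the original code and the corrected line return a 2-D (N, 6) tensor, and that is expected: each image can keep a different number of boxes, so the result cannot stay a regular (batch, boxes, 6) tensor. If the per-image grouping matters, one option (a sketch, not something proposed in the thread) is to filter each image separately and keep a Python list. `filter_per_image` is a hypothetical helper, and column 4 is again assumed to be the confidence.

```python
from typing import List

import torch


def filter_per_image(prediction: torch.Tensor, threshold_conf: float) -> List[torch.Tensor]:
    """Hypothetical helper: return one (n_i, C) tensor per image in the batch.

    Keeps the rows whose column 4 (assumed confidence) exceeds threshold_conf.
    """
    # Iterating over a tensor yields slices along dim 0, i.e. one image at a time.
    return [img[img[:, 4] > threshold_conf] for img in prediction]


# Usage with a toy batch of 2 images, 4 boxes each, 6 values per box.
prediction = torch.zeros(2, 4, 6)
prediction[0, :, 4] = torch.tensor([0.9, 0.1, 0.7, 0.2])
prediction[1, :, 4] = torch.tensor([0.3, 0.8, 0.1, 0.05])
per_image = filter_per_image(prediction, 0.5)  # shapes (2, 6) and (1, 6)
```

With a batch of one, as in the traceback above, `per_image[0]` is the single (n, 6) result.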

This worked perfectly, thanks!