Select Tensor by Row

I have two tensors, one holding the data and one holding values I would like to threshold on. Both share the dimension C:

a = torch.Tensor(B,C,X,Y)
b = torch.Tensor(B,C)

I would like to find those rows of C in tensor a that fulfill a criterion based on tensor b. For example:

# Boolean mask of b where b is zero
idx = (b==0)

The end result should look like:

bar(b)
foo(a)
a = torch.Tensor(B,C-N,X,Y)
# N is the number of channels that do not match the condition.

I don't completely understand the use case.
If you are using a threshold on b, you might only have a single True output in idx.
Indexing a with this idx tensor would return an output of [1, X, Y], not [B, 1, X, Y].
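To illustrate, here is a small sketch. The sizes and the positions of the zeros in b are arbitrary choices made for the example. It shows that indexing a with a 2D boolean mask over (B, C) collapses the selected (batch, channel) pairs into a single leading dimension:

```python
import torch

# Arbitrary example sizes (assumed for illustration)
B, C, X, Y = 2, 4, 3, 5
a = torch.randn(B, C, X, Y)
b = torch.ones(B, C)
b[0, 1] = 0  # one zero in the first sample
b[1, 3] = 0  # a zero at a different channel in the second sample

idx = b == 0          # boolean mask of shape [B, C]
selected = a[idx]     # the mask consumes dims 0 and 1
print(selected.shape) # [K, X, Y] with K = number of True entries in idx
```

The batch dimension does not survive: the result is [K, X, Y], ordered row-major over (B, C).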

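Could you explain a bit how the threshold should be applied to b?
Your expected output shape [B, C-N, X, Y] implies that the same N channels are removed for every sample in dim0, which only works if the condition on b gives the same result across the whole batch.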
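If the criterion really is shared by every sample, one option is to collapse the mask over dim0 first (here with `all`, which is an assumed rule, not something the question specifies) and index only dim1. The batch dimension is then kept and the result has shape [B, C-N, X, Y]. Sizes and data are made up:

```python
import torch

B, C, X, Y = 2, 4, 3, 5
a = torch.randn(B, C, X, Y)
b = torch.ones(B, C)
b[:, 2] = 0  # channel 2 is zero in every sample

# Assumed rule: keep channel c only if b[:, c] == 0 for every sample
keep = (b == 0).all(dim=0)  # shape [C]
a_kept = a[:, keep]         # shape [B, C - N, X, Y]
print(a_kept.shape)
```

If the mask differs between samples, the remaining channel count differs per sample, so no single [B, C-N, X, Y] tensor can hold the result.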

Suppose I have a tensor, a

a = torch.Tensor(B,C,X,Y)

I am trying to normalize each Y position to its Z-score, using the mean and standard deviation computed along X. However, some of the elements along the channel dimension, C, are degenerate.

This means I am only interested in those C elements that have a non-zero standard deviation along the X axis.
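For reference, a minimal sketch of that Z-score, assuming the mean and standard deviation are taken over X (dim=-2) separately for each Y position. It shows why degenerate channels are a problem: a channel that is constant along X has zero standard deviation, so the division becomes 0/0 and yields NaN. Sizes and values are made up:

```python
import torch

B, C, X, Y = 1, 3, 4, 2
a = torch.randn(B, C, X, Y)
a[0, 1] = 7.0  # channel 1 is constant along X, i.e. degenerate

mean = a.mean(dim=-2, keepdim=True)  # [B, C, 1, Y]
std = a.std(dim=-2, keepdim=True)    # [B, C, 1, Y]
z = (a - mean) / std                 # Z-score along X

print(std[0, 1])                    # all zeros for the degenerate channel
print(torch.isnan(z[0, 1]).all())   # the whole degenerate channel is NaN
```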

The way I achieved what I wanted was as follows:

# Std along the X axis, then its L2 norm over Y: zero only for degenerate (batch, channel) pairs
temp = torch.norm(torch.std(a, dim=-2), dim=-1)
# Find the non-zero indices of this tensor
idx = temp.nonzero()
# Slice the tensor along the indices
a_sliced = a[idx[:,0], idx[:,1], :, :]
# Perform some calculations on a_sliced
a_sliced = foo(a_sliced)
# Scatter the results back into a zero tensor of the same size as the original
a_unsliced = torch.zeros_like(a)
a_unsliced[idx[:, 0], idx[:, 1], :, :] = a_sliced
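A self-contained sketch of the same gather, compute, scatter pattern follows. `foo` is replaced by a hypothetical `zscore_along_x` that normalizes each [X, Y] slice with statistics over X; sizes and the degenerate slice are made up. It also shows that the boolean mask `temp != 0` can be used directly for both the read and the write, which skips the `nonzero()` step:

```python
import torch

def zscore_along_x(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for foo: Z-score over X (dim=-2) per Y position."""
    mean = t.mean(dim=-2, keepdim=True)
    std = t.std(dim=-2, keepdim=True)
    return (t - mean) / std

B, C, X, Y = 2, 3, 4, 5
a = torch.randn(B, C, X, Y)
a[1, 2] = 3.0  # make one (batch, channel) slice degenerate

# Norm over Y of the per-Y std along X, shape [B, C]
temp = torch.norm(torch.std(a, dim=-2), dim=-1)

# Index-based version
idx = temp.nonzero()                                  # [K, 2] (batch, channel) pairs
a_sliced = zscore_along_x(a[idx[:, 0], idx[:, 1]])    # [K, X, Y]
a_unsliced = torch.zeros_like(a)
a_unsliced[idx[:, 0], idx[:, 1]] = a_sliced

# Equivalent mask-based version
mask = temp != 0                                      # [B, C] boolean
a_masked = torch.zeros_like(a)
a_masked[mask] = zscore_along_x(a[mask])

print(torch.equal(a_unsliced, a_masked))  # True
```

Both versions visit the selected pairs in the same row-major order, so they produce identical tensors, and the degenerate slice stays zero instead of turning into NaN.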