I have two tensors, one with data and one with some value I would like to threshold on, both along dimension `C`:

```python
a = torch.Tensor(B, C, X, Y)
b = torch.Tensor(B, C)
```

I would like to find those slices along `C` in tensor `a` that fulfill a criterion based on tensor `b`. For example:
```python
# Boolean mask of the entries of b that are zero (a mask, not a list of indices)
idx = (b == 0)
```
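Note that `b == 0` produces a boolean mask of shape `(B, C)`, not indices. A minimal sketch with made-up values for `b` (B = 2, C = 3) shows the difference; `nonzero()` is what converts the mask into explicit (batch, channel) index pairs:

```python
import torch

# Hypothetical example values: B = 2 samples, C = 3 channels
b = torch.tensor([[0.0, 1.5, 0.0],
                  [2.0, 0.0, 3.0]])

# Comparison gives a boolean mask with the same shape as b, i.e. (B, C)
mask = b == 0

# nonzero() lists the (batch, channel) positions where the mask is True, shape (K, 2)
pairs = mask.nonzero()
print(pairs.tolist())
```

Either form can be used to index `a`; the mask is usually the more direct choice.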
The end result should look like:
```python
bar(b)
foo(a)
a = torch.Tensor(B, C - N, X, Y)
# N is the number of rows that do not match the condition.
```
I don't completely understand the use case.
If you are using a threshold on `b`, you might only have a single `True` output in `idx`.
Indexing `a` with this `idx` tensor would return an output of `[1, X, Y]`, not `[B, 1, X, Y]`.
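To make the shapes concrete, here is a small sketch with hypothetical sizes (B = 2, C = 3, X = 4, Y = 5) and made-up values in `b`. Indexing with a `(B, C)` boolean mask collapses the first two dimensions, so `K` matching entries give `[K, X, Y]`; the single-`True` case is `K = 1`:

```python
import torch

B, C, X, Y = 2, 3, 4, 5
a = torch.randn(B, C, X, Y)
b = torch.tensor([[0.0, 1.5, 0.0],
                  [2.0, 0.0, 3.0]])

mask = b == 0            # (B, C) boolean mask with 3 True entries
selected = a[mask]       # mask indexes dims 0 and 1 together, batch dim is flattened away
print(selected.shape)    # torch.Size([3, 4, 5])

# The same selection expressed with explicit (batch, channel) index pairs
pairs = mask.nonzero()
same = a[pairs[:, 0], pairs[:, 1]]
```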
Could you explain a bit how the threshold should be applied to `b`?
Currently it seems that you are assuming the same number of matching rows `N` for all elements in dim0.
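One way to keep the per-sample structure, sketched under the assumption that the criterion is `b != 0`: select channels sample by sample, and only stack the results into a `(B, C - N, X, Y)` tensor when every sample keeps the same number of channels. `select_channels` is a hypothetical helper, not a PyTorch function:

```python
import torch

def select_channels(a, mask):
    """Hypothetical helper: keep channels of a (B, C, X, Y) where mask (B, C) is True."""
    per_sample = [a[i][mask[i]] for i in range(a.shape[0])]  # each (C_kept_i, X, Y)
    counts = mask.sum(dim=1)
    if bool((counts == counts[0]).all()):
        # Same number of kept channels in every sample: stack into (B, C - N, X, Y)
        return torch.stack(per_sample)
    # Ragged case: sample sizes differ, so a single tensor is impossible; return the list
    return per_sample

a = torch.randn(2, 3, 4, 5)
b_even = torch.tensor([[1.0, 0.0, 2.0],
                       [0.0, 3.0, 4.0]])
b_ragged = torch.tensor([[0.0, 1.5, 0.0],
                         [2.0, 0.0, 3.0]])

even = select_channels(a, b_even != 0)      # both samples keep 2 channels
ragged = select_channels(a, b_ragged != 0)  # samples keep 1 and 2 channels
print(even.shape)                           # torch.Size([2, 2, 4, 5])
print([tuple(t.shape) for t in ragged])
```

In the ragged case the helper returns a Python list; padding every sample to the same channel count would be another option.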
Suppose I have a tensor `a`:

```python
a = torch.Tensor(B, C, X, Y)
```

I am trying to normalize the `Y` axis to its Z-score taken along statistics in `X`. However, some of the elements along the channel dimension, `C`, are degenerate. This means I am only interested in those `C` elements that have a non-zero standard deviation along the axis `X`.
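As an alternative to slicing, the degenerate channels can be masked instead. This sketch assumes one reading of the goal (mean and standard deviation computed along `X`, applied at every `Y` position) and sets degenerate positions to zero; `zscore_along_x` is a hypothetical name. `torch.where` swaps a zero standard deviation for 1 before dividing, so no NaNs appear:

```python
import torch

def zscore_along_x(a):
    """Hypothetical helper: Z-score a (B, C, X, Y) tensor using statistics along X."""
    mean = a.mean(dim=-2, keepdim=True)    # (B, C, 1, Y)
    std = a.std(dim=-2, keepdim=True)      # (B, C, 1, Y), unbiased estimator
    nonzero = std > 0
    safe_std = torch.where(nonzero, std, torch.ones_like(std))
    # Broadcast the (B, C, 1, Y) condition over X; degenerate positions become 0
    z = torch.where(nonzero, (a - mean) / safe_std, torch.zeros_like(a))
    valid = nonzero.any(dim=-1).squeeze(-1)  # (B, C): channel has some non-zero std
    return z, valid

a = torch.randn(2, 3, 4, 5)
a[0, 1] = 7.0                  # make channel 1 of sample 0 constant along X (degenerate)
z, valid = zscore_along_x(a)
print(valid)
```

Because nothing is sliced out, the output keeps the `(B, C, X, Y)` shape and `valid` records which channels were normalized.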
The way I achieved what I wanted was as follows:
```python
import torch

# a has shape (B, C, X, Y); foo is my own per-channel computation
# Standard deviation along the X axis (B, C, Y), reduced to one L2 norm per (B, C) entry
temp = torch.norm(torch.std(a, dim=-2), dim=-1)
# (batch, channel) index pairs with non-zero spread, shape (K, 2)
idx = temp.nonzero()
# Slice out the non-degenerate channels: shape (K, X, Y)
a_sliced = a[idx[:, 0], idx[:, 1], :, :]
# Perform some calculations on a_sliced
a_sliced = foo(a_sliced)
# Scatter the results back into a tensor shaped like the original;
# degenerate channels stay zero
a_unsliced = torch.zeros_like(a)
a_unsliced[idx[:, 0], idx[:, 1], :, :] = a_sliced
```
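The same round trip can also be written with a boolean mask instead of `nonzero()` indices, since `a[mask]` and `a_unsliced[mask] = ...` index the first two dimensions directly. In this sketch `foo` is a hypothetical stand-in that just doubles its input, and the sizes are arbitrary:

```python
import torch

def foo(x):
    # Hypothetical stand-in for the per-channel computation
    return x * 2.0

B, C, X, Y = 2, 3, 4, 5
a = torch.randn(B, C, X, Y)
a[1, 2] = 0.0                                    # degenerate channel (constant along X)

temp = torch.norm(torch.std(a, dim=-2), dim=-1)  # (B, C)
mask = temp != 0                                 # boolean mask instead of nonzero() indices

a_sliced = foo(a[mask])                          # (K, X, Y) with K = mask.sum()
a_unsliced = torch.zeros_like(a)
a_unsliced[mask] = a_sliced                      # masked-out channels stay zero
```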