I have a tensor `A` of shape `[X, Y, Z]` and an indexing tensor `B` of shape `[X, Y, 1]`. I would like to index `A` with `B` and replace those elements of `A` with 0.
`B` is the output of a neural network, and the desired result is that it is either true or false for each `Y`.
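Part of the difficulty is that the comparison `B < 0` returns a `torch.bool` tensor that carries no gradient. Whichever masking function is used, nothing flows back through the threshold to the network that produced `B`. Below is a small check of this, with `logits` as a hypothetical stand-in for the network's pre-activation output:

```python
import torch

torch.manual_seed(0)
X, Y, Z = 4, 5, 3
A = torch.rand(X, Y, Z, requires_grad=True)
# Hypothetical stand-in for the network's raw output
logits = torch.empty(X, Y, 1).uniform_(-10, 10).requires_grad_()
B = torch.tanh(logits)

idx = B < 0                            # bool tensor: the comparison has no gradient
print(idx.dtype, idx.requires_grad)    # torch.bool False

out = A.masked_fill(idx, 0.0)
out.sum().backward()

print(A.grad is not None)  # True: gradient reaches A (zero at masked entries)
print(logits.grad)         # None: nothing flows back through the threshold
```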
```python
import torch
import torch.nn as nn

r1, r2 = -10, 10  # range for the random pre-activation values
X, Y, Z = 20, 100, 5
activation = nn.Tanh()  # an activation function

# A is an arbitrary random tensor with values between 0 and 1
A = torch.empty(X, Y, Z).uniform_(0, 1)
# B is an arbitrary random tensor with values between -1 and 1
B = activation(torch.empty(X, Y, 1).uniform_(r1, r2))
# idx is True where B is less than 0
idx = B < 0
# foo is the unknown function: mask A with B so that A = 0 where idx is True
A = foo(A, B)
```
What is the correct function `foo` such that `A` is both masked with zeros and differentiable?
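One common workaround is a straight-through estimator. This is offered as an assumption, not as the definitive answer. The forward pass uses the exact hard mask, while the backward pass uses the gradient of a smooth surrogate. The sketch below picks `torch.sigmoid(k * B)` as that surrogate. The names `st_mask`, `k`, and `logits` are hypothetical, and the temperature `k = 10` is arbitrary.

```python
import torch

def st_mask(A, B, k=10.0):
    """Zero A where B < 0, with a straight-through gradient for B.

    Forward: exact hard mask (1 where B >= 0, 0 where B < 0).
    Backward: gradient of the surrogate sigmoid(k * B); k is an
    assumed temperature, not something from the question.
    """
    hard = (B >= 0).to(A.dtype)          # non-differentiable 0/1 mask, [X, Y, 1]
    soft = torch.sigmoid(k * B)          # differentiable surrogate, [X, Y, 1]
    keep = hard + (soft - soft.detach()) # value is exactly hard, gradient is d(soft)
    return A * keep                      # broadcasts over Z

torch.manual_seed(0)
X, Y, Z = 20, 100, 5
A = torch.rand(X, Y, Z)
# Hypothetical stand-in for the network's raw output
logits = torch.empty(X, Y, 1).uniform_(-10, 10).requires_grad_()
B = torch.tanh(logits)

out = st_mask(A, B)
out.sum().backward()  # logits.grad is now populated via the surrogate
```

The trade-off is that the gradient is biased, because it is the surrogate's gradient rather than that of the step function, which is zero almost everywhere. A plain soft mask such as `A * torch.sigmoid(k * B)` is also fully differentiable, but its values only approach zero and never reach it exactly.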