I have a tensor `A` of shape `[X,Y,Z]` and an indexing tensor `B` of shape `[X,Y,1]`. I would like to use `B` to index `A` and replace the selected elements of `A` with 0.
`B` is the output of a neural network, and the desired result is that it is either true or false for each element.
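For reference, here is a minimal sketch of the hard masking step on its own. It assumes the boolean mask has shape `[X, Y, 1]`. `Tensor.masked_fill` accepts any mask that broadcasts to the tensor's shape, so the trailing singleton dimension is expanded across `Z`. As a result, all `Z` entries at a masked `(x, y)` position are set to 0. The small sizes and random values are illustrative only.

```python
import torch

torch.manual_seed(0)
X, Y, Z = 2, 3, 4
A = torch.rand(X, Y, Z)
B = torch.rand(X, Y, 1) * 2 - 1  # stand-in values in [-1, 1)
idx = B < 0                      # boolean mask, shape [X, Y, 1]

# masked_fill broadcasts idx from [X, Y, 1] to [X, Y, Z]
A_masked = A.masked_fill(idx, 0.0)

print(A_masked.shape)  # torch.Size([2, 3, 4])
```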
```python
import torch
import torch.nn as nn

r1 = -10  # some ranges
r2 = 10
X = 20
Y = 100
Z = 5
activation = nn.Tanh()  # an activation function

# A is an arbitrary random tensor in [0, 1]
A = torch.FloatTensor(X, Y, Z).uniform_(0, 1)
# B is an arbitrary random tensor in [-1, 1]
B = activation(torch.FloatTensor(X, Y, 1).uniform_(r1, r2))
# idx is True where B is less than 0
idx = (B < 0)
# Some function to mask A with B such that where idx is True, A = 0
A = foo(A, B)
```
What is the correct function `foo` such that `A` is both masked with zero and differentiable?
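A note on what "differentiable" can mean here. `A.masked_fill(idx, 0)` and `torch.where(idx, torch.zeros_like(A), A)` are both differentiable with respect to `A`. The gradient is 1 at kept entries and 0 at zeroed ones. However, the comparison `B < 0` produces a boolean tensor with no gradient, so nothing flows back into `B` or the network that produced it. If that network must learn `B`, one common workaround is a straight-through estimator. This technique is swapped in here and is not something the question specifies. The forward pass uses the exact 0/1 mask, while the backward pass uses the gradient of a sigmoid surrogate. The function name `foo` comes from the question, but this body and the sharpness parameter `k` are hypothetical choices for illustration.

```python
import torch


def foo(A: torch.Tensor, B: torch.Tensor, k: float = 10.0) -> torch.Tensor:
    """Zero A wherever B < 0 while letting gradients reach both A and B.

    Hypothetical straight-through sketch; k is an assumed sharpness
    parameter for the sigmoid surrogate used only in the backward pass.
    """
    soft_keep = torch.sigmoid(k * B)   # ~1 where B > 0, ~0 where B < 0
    hard_keep = (B >= 0).to(A.dtype)   # exact 0/1 mask, carries no gradient
    # (soft - soft.detach()) is exactly 0 in value, so keep == hard_keep,
    # but its gradient with respect to B is that of soft_keep.
    keep = hard_keep + (soft_keep - soft_keep.detach())
    return A * keep                    # broadcasts [X, Y, 1] over [X, Y, Z]


torch.manual_seed(0)
A = torch.rand(20, 100, 5, requires_grad=True)
B = torch.tanh(torch.empty(20, 100, 1).uniform_(-10, 10)).requires_grad_()

out = foo(A, B)
out.sum().backward()  # populates both A.grad and B.grad
```

The trade-off is that `B.grad` comes from the surrogate rather than from the true step function. Larger `k` makes the surrogate closer to the hard mask but shrinks the gradient for values of `B` far from 0.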