Masking Rows of a Tensor

I have a tensor A of shape [X, Y, Z] and an indexing tensor B of shape [X, Y, 1]. I would like to use B to select elements of A and set those elements to 0.

B is the output of a neural network. The intent is for it to act as a true/false flag for each Y position (for every X), so that each flag either keeps or zeroes a whole length-Z row of A.
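A minimal sketch of the plain masking step, using small hypothetical shapes and hand-picked values for B (both are assumptions for illustration). A boolean mask of shape [X, Y, 1] broadcasts against A's Z dimension, so `torch.where` zeroes whole rows. The sketch also shows where gradients do and do not flow.

```python
import torch

# Hypothetical small shapes for illustration: X=2, Y=3, Z=4
X, Y, Z = 2, 3, 4
A = torch.rand(X, Y, Z, requires_grad=True)

# One hand-picked score per (x, y) position; unsqueeze gives B the shape [X, Y, 1]
B_raw = torch.tensor([[-0.5, 0.2, -0.1],
                      [0.3, -0.9, 0.7]], requires_grad=True)
B = B_raw.unsqueeze(-1)

idx = B < 0  # boolean mask of shape [X, Y, 1]

# The size-1 last dimension of idx broadcasts over Z,
# so every True entry zeroes a whole length-Z row of A.
masked = torch.where(idx, torch.zeros_like(A), A)
print(masked.shape)  # torch.Size([2, 3, 4])

masked.sum().backward()
# A receives gradient: 0 for zeroed rows, 1 for kept rows.
print(A.grad[0, 0], A.grad[0, 1])
# The comparison B < 0 is not differentiable, so no gradient reaches B.
print(idx.requires_grad, B_raw.grad)  # False None
```

This is differentiable with respect to A but not with respect to B, because comparison operators return boolean tensors outside the autograd graph.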

import torch
import torch.nn as nn

r1 = -10
r2 = 10  # range for the pre-activation values

X = 20
Y = 100
Z = 5

activation = nn.Tanh()  # an activation function

# A: arbitrary random tensor with values in [0, 1), shape [X, Y, Z]
A = torch.empty(X, Y, Z).uniform_(0, 1)

# B: arbitrary random tensor with values in [-1, 1], shape [X, Y, 1]
B = activation(torch.empty(X, Y, 1).uniform_(r1, r2))

# idx is True wherever B is negative
idx = B < 0

# Some function to mask A with B such that A = 0 wherever idx is True
A = foo(A, B)  # foo is the function being asked for

What should foo be so that A is set to zero wherever B is negative, while the whole operation remains differentiable?
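Since `B < 0` yields a boolean tensor with no gradient, a hard mask built with `torch.where` or `Tensor.masked_fill` is differentiable with respect to A but not B. If gradients must reach the network that produces B, one commonly used workaround (swapped in here, not something stated in the question) is a straight-through estimator: the forward pass applies the hard 0/1 mask, while the backward pass uses the gradient of `sigmoid(B / temperature)` as a surrogate. A sketch under that assumption; `st_mask` and `temperature` are hypothetical names.

```python
import torch


def st_mask(A: torch.Tensor, B: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: zero the length-Z rows of A wherever B < 0.

    A has shape [X, Y, Z] and B has shape [X, Y, 1]. The forward pass applies a
    hard 0/1 mask; the backward pass uses the gradient of sigmoid(B / temperature)
    instead (a straight-through estimator).
    """
    hard = (B >= 0).to(A.dtype)            # 1.0 keeps a row, 0.0 zeroes it
    soft = torch.sigmoid(B / temperature)  # smooth surrogate in (0, 1)
    # Forward value equals `hard`; gradient w.r.t. B is that of `soft`.
    mask = hard + soft - soft.detach()
    return A * mask                        # [X, Y, 1] broadcasts over Z


torch.manual_seed(0)
X, Y, Z = 20, 100, 5
A = torch.rand(X, Y, Z)
logits = torch.empty(X, Y, 1).uniform_(-10, 10).requires_grad_()
B = torch.tanh(logits)

out = st_mask(A, B)
out.sum().backward()
print(logits.grad.shape)  # torch.Size([20, 100, 1])
```

The straight-through gradient is biased, so whether it trains well depends on the task. If a fractional mask is acceptable instead of exact zeros, `A * torch.sigmoid(B / temperature)` is differentiable everywhere without the detach trick.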