# Most efficient way to mask tensor with other tensor to replace values with 0s

I have a (3, h, w) RGB tensor and a (1, h, w) tensor. The latter is mostly zeros, with only a few values > 0. I would like to set every value in the RGB tensor to 0 except at the (h, w) positions where the latter tensor is non-zero.
This should apply across the channel dimension, so the result is a (3, h, w) RGB tensor in which each channel is zero everywhere except at the positions that are non-zero in the (1, h, w) tensor. How can I do this most efficiently?

Thank you.

Huh, does the second tensor contain values (as opposed to int indexes)? If so, the first tensor is not used, and second.expand(3,-1,-1) (+ .contiguous() if needed) would broadcast across channels. If I misunderstood, you likely need where() or scatter() ops to merge values.

Yes, the second tensor contains values rather than int indexes. I basically want the first tensor to have 0s in the same places where the second tensor has 0s. I am not sure why the first tensor wouldn't be used. Here is an example:
rgb tensor:
tensor([[[ 23,  54],
         [  1,  90]],
        [[ 22, 190],
         [ 30,  10]]])

Other tensor:
tensor([[[12,  0],
         [ 0, 12]]])

Resultant tensor:
tensor([[[ 23,   0],
         [  0,  90]],
        [[ 22,   0],
         [  0,  10]]])

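a * b.bool() is enough for that, where a is the rgb tensor and b is the other tensor: b.bool() is True wherever b is non-zero, and the (1, h, w) mask broadcasts across the channel dimension. (Cast the mask to float manually if your torch version is very old.)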
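
For reference, here is a minimal sketch of that suggestion run on the example tensors from this thread. The names a, b, masked and masked_where are just illustrative, and the torch.where variant is an alternative added here for comparison, not something proposed above. It assumes a reasonably recent torch version in which multiplying an integer tensor by a bool tensor is supported.

```python
import torch

# Example tensors from the thread: `a` is the image-like tensor
# (2 channels in this toy example), `b` is the single-channel tensor
# whose non-zero entries mark the positions to keep.
a = torch.tensor([[[23, 54], [1, 90]],
                  [[22, 190], [30, 10]]])
b = torch.tensor([[[12, 0], [0, 12]]])

# b.bool() is True where b != 0. The (1, h, w) mask broadcasts over
# the channel dimension, so every channel of `a` is masked the same way.
masked = a * b.bool()

# Alternative formulation: pick from `a` where the mask is set, else 0.
masked_where = torch.where(b != 0, a, torch.zeros_like(a))

print(masked)
```

Both versions give the "Resultant tensor" shown above. One difference worth knowing for float inputs: multiplying by the mask keeps NaN or inf values at masked-out positions (NaN * 0 is NaN), while torch.where replaces them with 0.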