I have a (3, h, w) RGB tensor and a (1, h, w) tensor. The latter is sparse: mostly zeros, with some values > 0. I would like to set every value in the RGB tensor to 0 except at the (h, w) positions where the latter tensor is non-zero.
I would like this to apply across the channel dimension, so the result is a (3, h, w) RGB tensor in which every channel is zero except at the (h, w) positions that are non-zero in the (1, h, w) tensor. How can I do this most efficiently?
Huh, does the second tensor contain values (as opposed to integer indices)? If so, the first tensor isn't needed, and second.expand(3, -1, -1) (plus .contiguous() if needed) would broadcast it across the channels. If I misunderstood, you likely need where() or scatter() ops to merge the values.
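For reference, the two ops mentioned above look like this in PyTorch. This is a minimal sketch with made-up sizes and values: `h`, `w`, `rgb` and `mask_vals` are illustrative names, not from the question. `expand()` only creates a view (the channel dimension gets stride 0), while `torch.where()` selects values from `rgb` wherever the (1, h, w) condition is true, broadcasting it over the 3 channels.

```python
import torch

# Hypothetical sizes and data, for illustration only.
h, w = 4, 5
rgb = torch.randint(0, 256, (3, h, w))
mask_vals = torch.zeros(1, h, w)
mask_vals[0, 1, 2] = 3.0
mask_vals[0, 3, 0] = 7.0

# expand() returns a (3, h, w) view without copying: the channel dim gets stride 0.
expanded = mask_vals.expand(3, -1, -1)
# .contiguous() materializes a real copy if a later op needs contiguous memory.
expanded_copy = expanded.contiguous()

# torch.where takes rgb where the condition is True, else the zeros tensor.
# The (1, h, w) condition broadcasts against the (3, h, w) tensors.
result = torch.where(mask_vals != 0, rgb, torch.zeros_like(rgb))
```

Because `torch.where` broadcasts its arguments, there is no need to expand the mask to three channels first.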
Yes, the second tensor contains values rather than integer indices. Basically, I want the first tensor to have 0s in the same places where the second tensor has 0s. I am not sure why the first tensor wouldn't be used. Here is an example:
a:
[[[ 23,  54],
  [  1,  90]],
 [[ 22, 190],
  [ 30,  10]]]
b:
[[[  …,   0],
  [  0,  12]]]
desired result:
[[[ 23,   0],
  [  0,  90]],
 [[ 22,   0],
  [  0,  10]]]
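a * b.bool() is enough for that: b.bool() is True wherever b is non-zero, and the (1, h, w) mask broadcasts across the channel dimension of a. (Cast the mask to float manually if your torch version is very old.)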
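A minimal sketch reproducing the example above. Only the second row of b survived in the post, so b[0, 0, 0] = 5 below is a hypothetical non-zero placeholder, and b[0, 0, 1] = 0 is inferred from the desired result. The in-place `masked_fill_` variant at the end is an alternative not suggested in the thread; it avoids allocating a new tensor if you no longer need the original values.

```python
import torch

# `a` as given in the example (two channels shown there).
a = torch.tensor([[[23, 54],
                   [1, 90]],
                  [[22, 190],
                   [30, 10]]])
# b[0, 0, 0] = 5 is a hypothetical placeholder for the lost non-zero value;
# b[0, 0, 1] = 0 is inferred from the desired result.
b = torch.tensor([[[5, 0],
                   [0, 12]]])

# b.bool() is True where b != 0; multiplying by a bool tensor keeps a's integer
# dtype and broadcasts the (1, h, w) mask over the channel dimension.
result = a * b.bool()

# Alternative (not from the thread): zero out positions where b == 0 in place.
# The (1, h, w) mask broadcasts to a's shape; clone() keeps `a` intact here.
a_inplace = a.clone()
a_inplace.masked_fill_(b == 0, 0)
```

Both `result` and `a_inplace` equal the desired result from the question.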