Hello all, I have an input X of size BxCxHxW, where B is the batch size, and a corresponding label Y of size BxHxW. I want to keep the pixel values of X unchanged in the region where Y>0, and set the pixel values of X to -1 (as in the code below) in the region where Y<0. I used the code below:
mask = (Y > 0).float()
output = X * mask - (1 - mask)
It works for batch size B=1, but it does not work for batch size B>1. How should I correct it? Thanks.
The only solution I can think of uses a loop over the batch, but I do not think it is ideal:
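output = X.clone()  # copy, so that X itself is not modified in place
for i in range(batch_size):  # batch_size is B, i.e. X.size(0)
    mask = (Y[i, ...] > 0).float()  # HxW mask, broadcasts over the C channels of X[i]
    output[i, ...] = X[i, ...] * mask - (1 - mask)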
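
A loop-free version should be possible if the B>1 failure comes from broadcasting, which the shapes suggest: PyTorch aligns shapes starting from the trailing dimension, so a BxHxW mask is matched against the CxHxW part of X rather than against its batch dimension. The sketch below assumes X and Y are tensors with the shapes described above and adds a singleton channel dimension with unsqueeze(1). The function name and the example sizes are made up for illustration.

```python
import torch


def mask_outside_label(X, Y):
    """Keep X where Y > 0 and set it to -1 elsewhere, for the whole batch."""
    # Y is BxHxW; unsqueeze(1) turns the mask into Bx1xHxW,
    # which broadcasts across the C channels of the BxCxHxW input.
    mask = (Y > 0).float().unsqueeze(1)
    # Same formula as the original code: X where mask == 1, -1 where mask == 0.
    return X * mask - (1 - mask)


# Hypothetical sizes for illustration: B=4, C=3, H=W=5
X = torch.randn(4, 3, 5, 5)
Y = torch.randn(4, 5, 5)
output = mask_outside_label(X, Y)
```

This returns a new tensor and leaves X untouched. Pixels where Y is exactly 0 are also set to -1, because the mask only tests Y > 0.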