How to apply a mask to a whole batch?

Hello all, I have an input X of size BxCxHxW, where B is the batch size, and a corresponding label Y of size BxHxW. I want to keep the pixel values of X unchanged in the region where Y>0, and set them to -1 in the region where Y<=0. I used the code below:

mask = (Y > 0).float()
output = X * mask - (1 - mask)

It worked for batch size B=1, but it does not work for batch size B>1. How should I correct it? Thanks

The only solution I can think of uses a loop, but I don't think it is ideal:

output = torch.empty_like(X)  # preallocate the result
for i in range(batch_size):
    mask = (Y[i, ...] > 0).float()
    output[i, ...] = X[i, ...] * mask - (1 - mask)

Try unsqueezing Y in dim 1, so that the mask has shape Bx1xHxW and broadcasts over the channel dimension:

import torch

B, C, H, W = 2, 3, 4, 4
x = torch.randn(B, C, H, W)
y = torch.randn(B, H, W)

mask = (y.unsqueeze(1)>0).float()
output = x * mask - (1 - mask)
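
As for why the original fails: broadcasting aligns shapes from the trailing dimension, so a BxHxW mask is lined up against the CxHxW part of X, and its B ends up matched with C. That is fine for B=1, raises an error when B and C differ, and would silently mask the wrong dimension if B happened to equal C. Here is a minimal sketch that checks the unsqueezed version against the per-sample loop. The `torch.where` variant and the names `out_broadcast`, `out_loop`, and `out_where` are my own additions for illustration, not something from the posts above.

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
x = torch.randn(B, C, H, W)
y = torch.randn(B, H, W)

# Broadcasted version: mask is (B, 1, H, W) and expands over the C channels.
mask = (y.unsqueeze(1) > 0).float()
out_broadcast = x * mask - (1 - mask)

# Reference: the per-sample loop from the question.
out_loop = torch.empty_like(x)
for i in range(B):
    m = (y[i] > 0).float()  # (H, W), broadcasts against (C, H, W)
    out_loop[i] = x[i] * m - (1 - m)

# Alternative: select between x and a tensor of -1 with a boolean condition.
out_where = torch.where(y.unsqueeze(1) > 0, x, torch.full_like(x, -1.0))

print(torch.equal(out_broadcast, out_loop))
print(torch.equal(out_broadcast, out_where))
```

The `torch.where` form skips the float conversion and the arithmetic, which may read more clearly if the fill value ever changes from -1.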

Would that work for you?

Perfect! It worked…