How to mask/select a specific channel of a segmentation output?

Hello everyone,

I am using a segmentation model, and it gives me an output of shape BxCxHxW, where C is the number of classes.
During training I only want to optimize for class k, where k can be any index from 0 to C-1. So basically I only want to update the model for that specific channel, and I don't care about the other channels. In my optimization I use an L1 loss to force channel k to match a mask.
Currently, I select my channel via output[:,k], so I have L1Loss(output[:,k], mask), where output[:,k] and mask both have shape BxHxW.
But something seems to be wrong: in my output I see that the other channels also get updated.

I was wondering if someone has a suggestion on how to mask or select a specific channel of the output correctly.

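How are you checking the “updated channels”? Assuming you are using an nn.Conv2d layer as the output layer, each class channel is created by a separate filter, which uses all input channels by default. All of these filters also share the layers in front of them, so a loss on channel k updates those shared layers and can change the other output channels as well. So depending on what exactly you are checking, the updates might be expected.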
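To make this concrete, here is a minimal sketch with a hypothetical toy model: a shared conv layer followed by a 1x1 nn.Conv2d head (the layer sizes, class count, and k are assumptions). An L1 loss on output[:, k] only produces a gradient for filter k of the head, but the shared layer in front of it also receives a gradient, so after an optimizer step every output channel can change.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_classes, k = 4, 2  # hypothetical number of classes and target channel

# Hypothetical toy model: shared feature layer + per-class 1x1 conv head
shared = nn.Conv2d(3, 8, kernel_size=3, padding=1)
head = nn.Conv2d(8, num_classes, kernel_size=1)

x = torch.randn(2, 3, 16, 16)                  # B x 3 x H x W input
mask = (torch.rand(2, 16, 16) > 0.5).float()   # B x H x W target for class k

output = head(F.relu(shared(x)))               # B x C x H x W
loss = F.l1_loss(output[:, k], mask)           # only channel k enters the loss
loss.backward()

# Gradient norm per head filter: only filter k is non-zero
head_grad_per_filter = head.weight.grad.flatten(1).norm(dim=1)
print(head_grad_per_filter)

# The shared layer feeds every class channel and still receives a gradient
print(shared.weight.grad.abs().sum() > 0)
```

So selecting the channel with output[:, k] does restrict the loss to that channel; the other channels move only because they are computed from the same, now updated, shared features.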

Hmm, that is interesting. What do you mean by “How are you checking the updated channels”?
I guess I am not checking it, but I am also not sure how I should do it. Can you please give me a hint?

Here is a pseudocode version of my code:

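import torch

# self.deeplabv is the PyTorch DeepLab model; its output dict holds the logits under 'out'
prediction = torch.sigmoid(self.deeplabv(inputs)['out'])  # (2*batch, C, H, W)
batch = prediction.shape[0] // 2
output_1 = prediction[:batch]
output_2 = prediction[batch:2 * batch]
# select class channel k along dim 1 (as in output[:, k]), then unsqueeze to
# (batch, 1, H, W) so it broadcasts over the image channels of the input
output_1_1 = output_1[:, selecting_index_1].unsqueeze(1)
output_1_2 = output_2[:, selecting_index_2].unsqueeze(1)
# mask each half of the input with its selected channel and pass it to the second model
output_1_1_1 = self.model(inputs[:batch] * output_1_1)
output_1_2_1 = self.model(inputs[batch:2 * batch] * output_1_2)
loss = self.loss(output_1_1_1, mask_1) + self.loss(output_1_2_1, mask_2)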
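Whichever way the channel is picked, the indexing is worth checking: on a 4D tensor a single index such as prediction[k] selects along dim 0 (the batch), while prediction[:, k] selects class channel k. A small sketch with random tensors (all sizes and k are assumptions) shows the resulting shapes, and that prediction[:, k:k + 1] keeps the channel dimension without needing unsqueeze:

```python
import torch

B, C, H, W = 4, 21, 8, 8  # assumed sizes for illustration
k = 3                     # hypothetical class index

prediction = torch.sigmoid(torch.randn(B, C, H, W))

by_batch = prediction[k]              # selects sample k: (C, H, W)
by_channel = prediction[:, k]         # selects class k: (B, H, W)
keep_dim = prediction[:, k:k + 1]     # class k, channel dim kept: (B, 1, H, W)

print(by_batch.shape, by_channel.shape, keep_dim.shape)

# A (B, 1, H, W) mask broadcasts over the 3 image channels of the input
images = torch.randn(B, 3, H, W)
masked = images * keep_dim
print(masked.shape)
```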

In your initial post you mentioned that “in my output I see that the other channels also get updated”, so I wanted to double check what exactly you are checking, how, and what the expectations are.
The currently posted code snippet also doesn’t show any checks, so I’m unsure what the use case is and what seems to go wrong.
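One possible check is to snapshot all output channels before and after a single optimizer step and compare the change per channel. The sketch below uses a hypothetical stand-in model (shared conv, ReLU, 1x1 head) and plain SGD; the model, learning rate, and sizes are assumptions, not the setup from the posted code. Even though only channel k is in the loss, the other channels also change, because the shared layer was updated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, k = 4, 1  # hypothetical class count and target channel

# Hypothetical stand-in for the segmentation model
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, num_classes, kernel_size=1),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(2, 3, 16, 16)
mask = (torch.rand(2, 16, 16) > 0.5).float()

with torch.no_grad():
    before = model(x)

# One training step that only uses channel k in the loss
optimizer.zero_grad()
loss = F.l1_loss(model(x)[:, k], mask)
loss.backward()
optimizer.step()

with torch.no_grad():
    after = model(x)

# Mean absolute change per class channel (averaged over batch, H, W)
change_per_channel = (after - before).abs().mean(dim=(0, 2, 3))
print(change_per_channel)
```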

I see. I was visualizing the final output channels. My expectation was that, since I only select a specific channel, that channel should look similar to the mask and the others should look different.
However, when I visualize them, all of them look very similar.
For example, if I pass in a cat image, most of the channels show the cat mask, while I expected only the channel corresponding to the cat class to show it. One thought was that this might be because I am using sigmoid instead of softmax, but I am not sure whether it should still behave like that.
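A sketch of how the activation affects the gradients, assuming a random logit tensor in place of the model output: with sigmoid each channel is squashed independently, so a loss on channel k sends no gradient to the other channels' logits; with softmax over dim=1 the channels compete, so pushing channel k towards the mask also pushes the other logits. In both cases the other channels can still change through the shared layers of the network, so this alone does not make the unselected channels look different.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, k = 2, 5, 4, 4, 2  # assumed sizes and target class
mask = (torch.rand(B, H, W) > 0.5).float()
other = torch.arange(C) != k   # boolean selector for the non-target channels

# Sigmoid: each channel is independent
logits_sig = torch.randn(B, C, H, W, requires_grad=True)
F.l1_loss(torch.sigmoid(logits_sig)[:, k], mask).backward()
print(logits_sig.grad[:, other].abs().sum())   # zero: other logits get no gradient

# Softmax over the class dimension: channels compete for probability mass
logits_soft = torch.randn(B, C, H, W, requires_grad=True)
F.l1_loss(torch.softmax(logits_soft, dim=1)[:, k], mask).backward()
print(logits_soft.grad[:, other].abs().sum())  # non-zero: other logits are pushed too
```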

I am not sure I understand you correctly, so this may not be what you are looking for.

So the output of your model has shape (B, C, H, W), correct? And at each location there are unnormalized scores (logits) for each class? That would be the normal 'out' output of, for example, a torchvision.models.segmentation.deeplabv3_resnet101() model.

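If so, to get the predicted (highest-scoring) class at each location and use it for a downstream task, you can use output_predictions = output.argmax(0) on a single (C, H, W) output; for a batched (B, C, H, W) output, take the argmax over dim 1 instead.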
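A short sketch for a batched (B, C, H, W) output, where the class dimension is dim 1. It also turns the per-pixel prediction into a binary mask for a single class k; random logits stand in for the model output, and the sizes and k are assumptions:

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 21, 8, 8  # assumed sizes (21 classes)
k = 8                     # hypothetical class index

logits = torch.randn(B, C, H, W)      # stand-in for model(x)['out']

pred = logits.argmax(dim=1)           # (B, H, W): class index per pixel
class_k_mask = (pred == k).float()    # (B, H, W): 1 where class k wins

# Single-image version: argmax over dim 0 of a (C, H, W) tensor
single = logits[0].argmax(0)
print(pred.shape, class_k_mask.shape, torch.equal(single, pred[0]))
```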

For example (with 21 classes):

import torch
from PIL import Image

# model, input_batch and input_image come from the usual DeepLabV3 preprocessing (not shown)
with torch.no_grad():
    output = model(input_batch)['out'][0]  # first sample: (21, H, W)
output_predictions = output.argmax(0)  # (H, W) map of class indices

# create a color palette, selecting a color for each class
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# plot the semantic segmentation predictions of 21 classes in each color
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

import matplotlib.pyplot as plt