Merging classes of the output of the semantic segmentation module

Hi all,

My question is about the image segmentation task.
I have a tensor of size (batch_size, 150, height, width), where the second dimension (150) corresponds to the number of classes.

Now I want to merge those classes into 25 classes by summing their probabilities.
First of all, is this procedure correct? It seems it should be; for example, I want to merge river and lake into one class.

Note: I have the source and target indices to be merged. For example:

target=[1, 2]
source= [[3,4],
         [5,6]]

# for clarity, I use a 1-D tensor here
input = [1, 3, 4, 2, 5, 7, 6]

expected_output = [1, 1, 1, 2, 2, 7, 2] 
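For the 1-D label example above, one way to sketch the remapping is with a lookup table indexed by the old class labels (the class count of 8 here is an assumption just to bound the table):

```python
import torch

# Hypothetical sketch: remap class labels via a lookup table.
# Values 3 and 4 become 1; values 5 and 6 become 2; the rest stay as-is.
input = torch.tensor([1, 3, 4, 2, 5, 7, 6])

num_classes = 8  # assumed upper bound on the class indices
lookup = torch.arange(num_classes)          # identity mapping by default
lookup[torch.tensor([3, 4])] = 1            # source [3, 4] -> target 1
lookup[torch.tensor([5, 6])] = 2            # source [5, 6] -> target 2

output = lookup[input]
print(output)  # tensor([1, 1, 1, 2, 2, 7, 2])
```

Note this remaps hard labels; merging probabilities (the main question) needs a summation over channels instead.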

My question is, what is the optimal way to do this process?

My attempt: first, I tried to build a one-hot encoded vector for each set of source indices that should be converted to a target. Then I used these vectors as masks and a linear multiplication to get the aggregated value, applied to all pixels at the same time.

By the way, I am still unsure what is the efficient way to achieve the goal using PyTorch available functions.
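The one-hot/mask idea described above could be sketched as a binary (150, 25) matrix M, where M[old, new] = 1 if old class `old` maps to merged class `new`, contracted over the old-class dimension. The shapes and the mapping below are illustrative assumptions:

```python
import torch

old_classes, new_classes = 150, 25
batch, H, W = 2, 4, 4

# assumed mapping for the demo: old class i -> merged class i // 6
mapping = torch.arange(old_classes) // 6

# binary merge matrix: one 1 per row, in the column of the merged class
M = torch.zeros(old_classes, new_classes)
M[torch.arange(old_classes), mapping] = 1.0

probs = torch.softmax(torch.randn(batch, old_classes, H, W), dim=1)

# contract over the old-class channel dimension
merged = torch.einsum('bchw,ck->bkhw', probs, M)

print(merged.shape)  # torch.Size([2, 25, 4, 4])
# per-pixel probabilities still sum to 1 after merging
```

This works, but it materializes a dense matmul; an index-based scatter is usually cheaper for a pure regrouping.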

I’m not sure if I understand the use case correctly, but would scattering the output logits using a mapping work?
Here is a small example:

import torch
import torch.nn as nn

batch_size = 10
nb_classes = 2
nb_features = 5
nb_out = 10

x = torch.randn(batch_size, nb_features)
target = torch.randint(0, nb_classes, (batch_size,))

# mapping[i] gives the merged class for original class i
mapping = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

model = nn.Linear(nb_features, nb_out)
criterion = nn.CrossEntropyLoss()

output = model(x)
# sum the logits of all original classes that share a merged class
output_small = torch.zeros(batch_size, nb_classes).scatter_add(
    1, mapping.unsqueeze(0).expand_as(output), output)
loss = criterion(output_small, target)
loss.backward()
print(model.weight.grad)

scatter_add() is exactly the function I was looking for.

By the way, I have a pretrained model and I do not need to train it. I just want to merge some of the classes in its predictions before using the .max() function to get the corresponding class.
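For this inference-only case, the same scatter_add idea can be applied directly to the (batch, 150, height, width) output along the channel dimension, followed by argmax. The mapping tensor below is an assumed placeholder; in practice it would encode the actual 150-to-25 class grouping:

```python
import torch

batch, old_c, new_c, H, W = 2, 150, 25, 8, 8
probs = torch.softmax(torch.randn(batch, old_c, H, W), dim=1)

# assumed (150,) mapping for the demo: old class i -> merged class i % 25
mapping = torch.arange(old_c) % new_c

# broadcast the mapping to the full output shape so scatter_add_ can use it
index = mapping.view(1, old_c, 1, 1).expand_as(probs)

# sum the probabilities of all old classes that share a merged class
merged = torch.zeros(batch, new_c, H, W).scatter_add_(1, index, probs)

# final per-pixel class map, shape (batch, H, W)
pred = merged.argmax(dim=1)
print(pred.shape)  # torch.Size([2, 8, 8])
```

Since no gradients are needed here, this can also be wrapped in torch.no_grad().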

Thank you so much.
