Confused by torch.scatter

Hello everybody,
I have been struggling for a few days now to get the scatter function working the way I want. I’m sorry, I tried to understand the documentation but failed. My use case is the following:

I have a tensor of size [1, height, width]. It’s a semantic mask for a segmentation network, and every pixel takes an integer class value between 0 and 6.

As I understand the documentation (I’m using the out-of-place version of scatter):

  • The index tensor tells scatter at which index each value of the src tensor should be written. In my case, I expect the scattered tensor to go from size [1, height, width], with values between 0 and 6, to size [7, height, width], with only two values, 0 and 1. So is the segmentation mask the index? That’s what I was thinking, because if a pixel has the value 3, it should stay at the same (x, y) position but be moved to index 3 along dimension 0 of the new, scattered tensor.

  • The dim argument tells the function along which dimension to scatter the values. In my case I’m fairly sure it’s 0 (the channels-first layout expected by torch.nn.BCEWithLogitsLoss()).

  • I’m confused by the src argument. Should I create an additional torch.ones() tensor that contains the values to be moved according to index?
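For reference, the documented rule for scatter on a 3-D tensor with dim=0 is out[index[i][j][k]][j][k] = src[i][j][k]: the value stored in index selects the channel, while the spatial position (j, k) is kept. The sketch below checks that rule against an explicit loop on a tiny, made-up 2x2 mask (the mask values are illustrative only):

```python
import torch

# Hypothetical 2x2 mask with class ids in [0, 6]; scatter needs int64 indices,
# which is the default dtype for torch.tensor of Python ints.
mask = torch.tensor([[[3, 0],
                      [6, 3]]])  # shape [1, 2, 2]
num_classes = 7

# Scatter along dim 0: out[mask[0, j, k], j, k] = src[0, j, k]
src = torch.ones(mask.shape)
out = torch.zeros(num_classes, 2, 2).scatter(0, mask, src)

# The same assignment written as an explicit loop.
expected = torch.zeros(num_classes, 2, 2)
for j in range(2):
    for k in range(2):
        expected[int(mask[0, j, k]), j, k] = src[0, j, k]

assert torch.equal(out, expected)
print(out[3])  # channel 3 holds a 1 wherever the mask equals 3
```

Each pixel ends up with exactly one 1, in the channel named by its class id, which is the [7, height, width] one-hot layout described above.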

Based on my understanding, I wrote the following code:
x = torch.scatter(dim = 0, index = semantic_mask, src = torch.ones(semantic_mask.size()))

I chose the size argument of torch.ones because, according to the documentation, index and src should have the same size, but that doesn’t work. It’s probably a silly mistake, but I’m really stuck here.
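A plausible cause, offered as an assumption since the error message isn’t shown: the out-of-place torch.scatter also takes an input tensor to scatter into, and the result has the shape of that input, not of index. The call above omits it. Here is a sketch of the out-of-place call using a [7, height, width] zero tensor as input. semantic_mask is random stand-in data. If a real mask is loaded as uint8, it would likely need .long() first, because scatter expects int64 indices.

```python
import torch

height, width = 4, 5
# Stand-in for the real mask: int64 class ids in [0, 6], shape [1, H, W].
semantic_mask = torch.randint(0, 7, (1, height, width))

# The output takes the shape of `input`, so this zeros tensor fixes the
# number of channels (7 classes) of the one-hot result.
target = torch.zeros(7, height, width)

# src variant: src must be at least as large as index in every dim
# and have the same dtype as input (float32 here).
x = torch.scatter(target, dim=0, index=semantic_mask,
                  src=torch.ones(semantic_mask.size()))

# Scalar variant: a single value replaces the ones tensor.
y = torch.scatter(target, 0, semantic_mask, 1.0)

assert torch.equal(x, y)
print(x.shape)  # torch.Size([7, 4, 5])
```

The scalar form avoids allocating the extra ones tensor, which answers the src question from the list above: a separate torch.ones() tensor works, but it is not required.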

Your workflow sounds right, and this code snippet should work:

import torch

h, w = 3, 3
mask = torch.randint(0, 7, (1, h, w))  # int64 class ids in [0, 6]

res = torch.zeros(7, h, w)  # one channel per class
res.scatter_(0, mask, 1.)   # res[mask[0, i, j], i, j] = 1

Could you check it locally and verify that the result contains the expected values?
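One possible way to run that check is sketched below. It confirms that the result contains only 0s and 1s with exactly one 1 per pixel, and that argmax over the channels recovers the mask. It also cross-checks against torch.nn.functional.one_hot. That function is a different built-in route to the same encoding and is not part of the original answer. It puts the class dimension last, so the result is permuted to channels-first.

```python
import torch
import torch.nn.functional as F

h, w = 3, 3
mask = torch.randint(0, 7, (1, h, w))

res = torch.zeros(7, h, w)
res.scatter_(0, mask, 1.)

# Only 0s and 1s, and exactly one 1 per pixel.
assert set(res.unique().tolist()) <= {0.0, 1.0}
assert torch.all(res.sum(dim=0) == 1)

# argmax over the channel dimension recovers the original class ids.
assert torch.equal(res.argmax(dim=0), mask[0])

# Cross-check with F.one_hot: [H, W, 7] -> [7, H, W], cast to float.
ref = F.one_hot(mask[0], num_classes=7).permute(2, 0, 1).float()
assert torch.equal(res, ref)
```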

Yes, it works fine. Thanks a lot for your help!