Usage of torch.scatter() for multi-dimensional value

Hi!

I have a question about how to use the torch.scatter() function.

I want to construct a weight matrix, weights, of shape [B, N, V], where B is the batch size, N is the number of points, and V is the number of features per point.

Let’s say I have two tensors:

a = ...  # shape [B, N, k]; for each of the N points in each of the B batches, a stores k feature indices in the range [0, V - 1], i.e. which features to select.
b = ...  # shape [B, N, k]; b stores the weight for each selected feature.

I tried to use torch.scatter():
weights.scatter_(index=a, dim=2, value=some_fix_value)
With this call I can only set a single fixed value, not the values stored in b at those locations.

Can someone give me a hint on how to do this properly?
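For reference, Tensor.scatter_ also accepts a tensor src in place of the scalar value, which writes b[i, j, m] into weights[i, j, a[i, j, m]]. A minimal sketch with random hypothetical inputs, assuming the k indices per point are distinct (with duplicate indices, which value ends up in weights is not guaranteed):

```python
import torch

B, N, V, k = 2, 3, 4, 2

# Hypothetical inputs matching the shapes described above:
# a holds k distinct feature indices per point, b holds the matching weights.
a = torch.rand(B, N, V).argsort(dim=-1)[..., :k]  # [B, N, k], values in [0, V - 1]
b = torch.randn(B, N, k)                          # [B, N, k]

# Start from zeros and write b[i, j, m] into weights[i, j, a[i, j, m]].
weights = torch.zeros(B, N, V)
weights.scatter_(2, a, b)  # dim=2, index=a, src=b (a tensor instead of a scalar value)

print(weights.shape)  # torch.Size([2, 3, 4])
```

Every position not listed in a keeps its initial 0, so weights is zero outside the selected features.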

Direct indexing should work if I understand your use case correctly:

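import torch

B, N, V = 2, 3, 4

a = torch.randn(B, N, V)               # source values, [B, N, V]
k = 2
idx = torch.randint(0, V, (B, N, k))   # k feature indices per point, [B, N, k]

b = torch.zeros(B, N, V)               # output, zero everywhere by default

# Batch and point indices shaped to broadcast with idx to [B, N, k]
r = torch.arange(b.size(0))[:, None, None]  # [B, 1, 1]
c = torch.arange(b.size(1))[None, :, None]  # [1, N, 1]

# Copy only the selected features; all other entries of b stay 0
b[r, c, idx] = a[r, c, idx]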
print(idx)
# tensor([[[0, 1],
#          [2, 1],
#          [1, 2]],

#         [[0, 3],
#          [0, 3],
#          [3, 1]]])
print(a)
# tensor([[[ 1.0250, -1.8356,  0.4314,  2.1630],
#          [-0.8884, -0.2196, -1.5033,  0.8229],
#          [-0.1390, -0.9114, -0.2310, -0.7310]],

#         [[ 0.7332, -1.5779, -0.4527, -1.6785],
#          [ 0.6614, -0.0094, -0.1890,  0.3890],
#          [-0.4206, -1.1668, -0.1563, -0.6945]]])
print(b)
# tensor([[[ 1.0250, -1.8356,  0.0000,  0.0000],
#          [ 0.0000, -0.2196, -1.5033,  0.0000],
#          [ 0.0000, -0.9114, -0.2310,  0.0000]],

#         [[ 0.7332,  0.0000,  0.0000, -1.6785],
#          [ 0.6614,  0.0000,  0.0000,  0.3890],
#          [ 0.0000, -1.1668,  0.0000, -0.6945]]])
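The indexing above is equivalent to gathering the selected values with torch.gather and writing them into a zero tensor with Tensor.scatter along the feature dimension. A minimal sketch comparing the two; the names picked, b_index and b_scatter are illustrative. Duplicate indices in idx are harmless here because both approaches write the same a value to the same position:

```python
import torch

B, N, V, k = 2, 3, 4, 2
a = torch.randn(B, N, V)
idx = torch.randint(0, V, (B, N, k))

# Advanced indexing, as above: r and c broadcast against idx to [B, N, k].
r = torch.arange(B)[:, None, None]  # [B, 1, 1]
c = torch.arange(N)[None, :, None]  # [1, N, 1]
b_index = torch.zeros(B, N, V)
b_index[r, c, idx] = a[r, c, idx]

# Same result with gather + scatter along the feature dimension (dim=2).
picked = a.gather(2, idx)                                 # [B, N, k]
b_scatter = torch.zeros(B, N, V).scatter(2, idx, picked)  # [B, N, V]

print(torch.equal(b_index, b_scatter))  # True
```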

Hi @ptrblck,

Thank you! This is generally what I mean, but in my case the tensor a has shape [B, N, k] instead of [B, N, V]. I only know the values at the k selected indices, and I want to set all other locations to 0. In other words, I want to turn a from [B, N, k] into [B, N, V], keeping the original values at the k indices and filling every other location with 0. Do you have an idea?
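If the values are already available as a [B, N, k] tensor, they can be assigned directly through the same index tensors, so no [B, N, V] source is needed. A minimal sketch, assuming the values live in a hypothetical tensor vals and that the k indices per point are distinct:

```python
import torch

B, N, V, k = 2, 3, 4, 2

# Assumed inputs: vals already has shape [B, N, k] (one value per selected index),
# and idx holds k distinct feature indices per point.
idx = torch.rand(B, N, V).argsort(dim=-1)[..., :k]  # [B, N, k]
vals = torch.randn(B, N, k)                         # [B, N, k]

r = torch.arange(B)[:, None, None]  # [B, 1, 1]
c = torch.arange(N)[None, :, None]  # [1, N, 1]

# Expand [B, N, k] -> [B, N, V]: selected positions get vals, everything else stays 0.
out = torch.zeros(B, N, V)
out[r, c, idx] = vals

print(out.shape)  # torch.Size([2, 3, 4])
```

If idx can contain the same index twice for one point, which of the two values ends up in out is not guaranteed, so distinct indices are the safe case.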

Isn’t this exactly what my code is doing?
It selects values from the tensor a of shape [B, N, V] at the indices in idx, of shape [B, N, k], and assigns them to b, which is initialized with zeros for all other values.
If not, could you post a slow reference implementation, please?
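A slow reference implementation of the [B, N, k] to [B, N, V] expansion can be written with explicit loops and checked against a vectorized version. A minimal sketch; the function names expand_reference and expand_vectorized are hypothetical, and idx is assumed to hold distinct indices per point:

```python
import torch


def expand_reference(idx: torch.Tensor, vals: torch.Tensor, V: int) -> torch.Tensor:
    """Slow loop-based reference: out[i, j, idx[i, j, m]] = vals[i, j, m], zeros elsewhere."""
    B, N, k = idx.shape
    out = torch.zeros(B, N, V, dtype=vals.dtype)
    for i in range(B):
        for j in range(N):
            for m in range(k):
                out[i, j, idx[i, j, m]] = vals[i, j, m]
    return out


def expand_vectorized(idx: torch.Tensor, vals: torch.Tensor, V: int) -> torch.Tensor:
    """Vectorized equivalent: scatter vals along the feature dimension into zeros."""
    B, N, _ = idx.shape
    return torch.zeros(B, N, V, dtype=vals.dtype).scatter(2, idx, vals)


B, N, V, k = 2, 3, 4, 2
idx = torch.rand(B, N, V).argsort(dim=-1)[..., :k]  # distinct indices per point
vals = torch.randn(B, N, k)

print(torch.equal(expand_reference(idx, vals, V), expand_vectorized(idx, vals, V)))  # True
```

The loop version is only meant as a correctness check; the scatter (or indexing) version is the one to use in practice.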

@ptrblck Thank you, you are right.