I have a question about how to use the torch.scatter() function.
I want to construct a weight matrix, weights, of shape [B, N, V], where B is the batch size, N is the number of points, and V is the number of features per point.
Let's say I have two tensors:
a, of shape [B, N, k]: for each point, k feature indices in the range 0 to V-1 that select which features to set.
b, of shape [B, N, k]: for each point, the weights of the k selected features.
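To make the layout concrete, here is a minimal sketch that builds tensors with the shapes described above. The sizes (B=2, N=3, V=5, k=2) and the random contents are assumptions for illustration only; what matters is that a holds int64 indices in 0..V-1 and b has exactly the same shape as a.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
B, N, V, k = 2, 3, 5, 2

# a: for every point, k distinct feature indices in 0..V-1 (int64, as scatter_ requires).
a = torch.stack([torch.randperm(V)[:k] for _ in range(B * N)]).view(B, N, k)

# b: one weight per selected feature, same shape as a.
b = torch.rand(B, N, k)

# weights: the dense [B, N, V] matrix that should receive the values from b.
weights = torch.zeros(B, N, V)
```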
I tried using scatter_():
weights.scatter_(index=a, dim=2, value=some_fix_value)
With this call I can only write one fixed value into every indexed position, not the corresponding entries of b, which hold the actual weight for each location.
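A minimal sketch contrasting the two forms of scatter_: with a scalar, every indexed position gets the same number; with a tensor source of the same shape as the index, each position gets its own value, i.e. weights[i][j][a[i][j][m]] = b[i][j][m]. The small deterministic a and b below are assumptions chosen so the result is easy to check by hand.

```python
import torch

B, N, V, k = 2, 3, 5, 2

# Hypothetical indices (distinct within each point) and weights, for illustration.
a = torch.tensor([[[0, 2], [1, 4], [3, 0]],
                  [[4, 1], [2, 3], [0, 4]]])
b = torch.arange(1, 13, dtype=torch.float32).view(B, N, k)

# Scalar form (what the question uses): every selected slot gets the same number.
weights_fixed = torch.zeros(B, N, V).scatter_(2, a, 0.5)

# Tensor-source form: each selected slot gets its own value from b.
# b must have the same dtype as the destination tensor.
weights_from_b = torch.zeros(B, N, V).scatter_(2, a, b)
```

If a can contain the same index twice for one point, scatter_ does not define which value wins; in that case Tensor.scatter_add_(2, a, b) (which sums colliding values) may be the better fit.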
Can someone give me a hint on how to do this properly?