Hello.
I was previously using several for loops for the use case below, but it is very slow because my tensors are large. I looked at the documentation of torch.where, but the pattern in my tensors is irregular, so I am stuck.
I would be grateful if someone could provide a dummy example.
Use case:
I have two 2D tensors, say A and B, both of the same shape.
Tensor-A, say torch.randn(5,5), is:
[ 0.1923, 1.6150, -0.4331, 0.7061, 1.7127],
[-0.4912, 0.5317, -1.0820, -0.2575, 0.3446],
[ 0.3956, 0.7645, -0.7015, 0.0574, 0.6930],
[ 0.4492, 0.0215, -1.1855, -1.1453, 0.3912],
[-0.5674, -0.8794, -0.1316, 0.4391, -1.0830]
Tensor-B consists of only three unique integer values (say 1, 2, 3). Each row contains some 1s, 2s, and 3s, in varying proportions. Say Tensor-B, from torch.randint(1,4,(5,5)), is:
[3, 1, 1, 1, 2],
[3, 1, 2, 3, 3],
[3, 2, 2, 1, 1],
[2, 2, 3, 3, 1],
[2, 3, 3, 1, 1]
##########
What I want to do is, for each row:
- Sum all the values of Tensor-A at the indices where Tensor-B is 1.
- Divide each value of Tensor-A at the indices where Tensor-B is 2 by the sum from the previous step.
- Take the exponent of each of those divided values (the 2-positions) and sum the results for that row.
i.e., for the above tensors, taking the 3rd rows of A and B:
0.0574 + 0.6930 = 0.7504
0.7645 / 0.7504 = 1.019, -0.7015 / 0.7504 = -0.9348
exp(1.019) + exp(-0.9348) = 3.163
Similarly, after doing the same for each row, the final result is Row 1's value + Row 2's + 3.163 + …
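A fully vectorized sketch of the three steps, using torch.where with boolean masks on B so no Python loop is needed. It hardcodes the example tensors from the post to check against the worked row-3 value; in practice A and B would come from torch.randn / torch.randint as above. Note this assumes every row of B contains at least one 1, otherwise the per-row sum is zero and the division produces inf.

```python
import torch

# The exact tensors from the post (normally A = torch.randn(5, 5),
# B = torch.randint(1, 4, (5, 5))).
A = torch.tensor([[ 0.1923,  1.6150, -0.4331,  0.7061,  1.7127],
                  [-0.4912,  0.5317, -1.0820, -0.2575,  0.3446],
                  [ 0.3956,  0.7645, -0.7015,  0.0574,  0.6930],
                  [ 0.4492,  0.0215, -1.1855, -1.1453,  0.3912],
                  [-0.5674, -0.8794, -0.1316,  0.4391, -1.0830]])
B = torch.tensor([[3, 1, 1, 1, 2],
                  [3, 1, 2, 3, 3],
                  [3, 2, 2, 1, 1],
                  [2, 2, 3, 3, 1],
                  [2, 3, 3, 1, 1]])

zeros = torch.zeros_like(A)

# Step 1: per-row sum of A at the 1-positions of B.
# keepdim=True keeps shape (5, 1) so it broadcasts over each row below.
s = torch.where(B == 1, A, zeros).sum(dim=1, keepdim=True)

# Steps 2-3: divide by the row sum, exponentiate, then keep only the
# 2-positions before summing per row.  The masking must happen *after*
# the exp, because exp(0) = 1 would otherwise pollute the sum.
per_row = torch.where(B == 2, torch.exp(A / s), zeros).sum(dim=1)

print(per_row)        # per_row[2] is roughly 3.163, matching the worked example
total = per_row.sum() # final scalar: Row 1's value + Row 2's + 3.163 + ...
```

The same masks can also be written as `(A * (B == 1)).sum(...)`, but torch.where makes the "pick A here, zero elsewhere" intent explicit, which seemed to be the pattern the original loops implemented.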