Hello Youzunzhi!
I don’t know of any way to avoid a loop without summing over the
full tensor, for example by using .cumsum() (or Zijing’s masked-sum
suggestion).
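For example, .cumsum() turns each row into its running partial sums
in a single vectorized call (shown here with the newer torch.tensor
constructor rather than the torch.FloatTensor style used below):

```python
import torch

t = torch.tensor([[1.0, 3.0, 4.0, 2.0, 1.0],
                  [2.0, 3.0, 3.5, 4.0, 3.5]])

# running partial sums along each row
# row 0 becomes 1, 4, 8, 10, 11; row 1 becomes 2, 5, 8.5, 12.5, 16
tsum = t.cumsum(dim=1)
print(tsum)
```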
After forming the sums, you have to index into them to get the
desired partial sum. I don’t know of a “clean” way of doing this;
the best I could come up with is to use .take().
To accommodate the case in which an element of c is 0 (so that
you sum over zero elements of t, getting 0.0), we “initialize” the
partial sums with 0.0 by prepending a zero slice to t.
Here is my approach:
import torch
torch.__version__
t = torch.FloatTensor([[1.0000, 3.0000, 4.0000, 2.0000, 1.0000],
[2.0000, 3.0000, 3.5000, 4.0000, 3.5000]])
c = torch.LongTensor([2, 3])
t0 = torch.cat ((torch.zeros (t.shape[0], 1), t), 1) # initialize partial sums with 0
tsum = t0.cumsum (dim = 1) # calculate partial sums
cp = c + t0.shape[1] * torch.arange (c.shape[0]).long() # 1-d indices so we can use take()
sc = tsum.take (cp) # get specified partial sums
print ('sc =\n', sc) # print result
And here is the output:
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> t = torch.FloatTensor([[1.0000, 3.0000, 4.0000, 2.0000, 1.0000],
... [2.0000, 3.0000, 3.5000, 4.0000, 3.5000]])
>>> c = torch.LongTensor([2, 3])
>>>
>>> t0 = torch.cat ((torch.zeros (t.shape[0], 1), t), 1) # initialize partial sums with 0
>>> tsum = t0.cumsum (dim = 1) # calculate partial sums
>>> cp = c + t0.shape[1] * torch.arange (c.shape[0]).long() # 1-d indices so we can use take()
>>> sc = tsum.take (cp) # get specified partial sums
>>>
>>> print ('sc =\n', sc) # print result
sc =
4.0000
8.5000
[torch.FloatTensor of size 2]
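As an aside, on a more recent PyTorch build than the 0.3 version
shown above you could index each row of the partial sums directly
with .gather(), which avoids the flat-index arithmetic that .take()
requires (this is a sketch of an alternative, not the approach in
the transcript):

```python
import torch

t = torch.tensor([[1.0, 3.0, 4.0, 2.0, 1.0],
                  [2.0, 3.0, 3.5, 4.0, 3.5]])
c = torch.tensor([2, 3])

# same trick as above: prepend a zero column so c = 0 selects 0.0
t0 = torch.cat((torch.zeros(t.shape[0], 1), t), 1)
tsum = t0.cumsum(dim=1)

# gather selects tsum[i, c[i]] per row, so no 1-d index offsets are needed
sc = tsum.gather(1, c.unsqueeze(1)).squeeze(1)
print(sc)  # sc holds [4.0, 8.5], matching the output above
```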
Best.
K. Frank