all = torch.zeros(batch, n, dim)
for B in range(batch):
    wpe = torch.zeros(n, dim)
    for i in range(n):
        wpe[w][:] = torch.sum(atten_matrix[B][:][i] * ee.permute(0, 2, 1), dim=2)
        # ee.permute(0, 2, 1).size() => batch x dim x L
        # atten_matrix.size()        => batch x L x L
    all[B] = wpe
Did you mean something like wpe[w][:] += ... or wpe[i][:] = ...?
It would be best if you could post a complete, runnable script
(including small sample data) that does what you want, and
also post the results. Then we could see for sure what you
are trying to do, and see if we can reproduce those results
with a loop-free approach.
Thanks, Frank, for the reply.
Sorry, that was my mistake: in wpe[w][:] = torch.sum(atten_matrix[B][:][i] * ee.permute(0,2,1), dim=2), w should be i. As for wpe: atten_matrix[B][:][i] * ee.permute(0,2,1) gives me L vectors (L x dim), and I sum them along the third dimension (dim=2), which yields a single vector (1 x dim).
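If I'm reading the shape comments correctly (atten_matrix is batch x L x L, ee is batch x L x dim, and n == L), then the weighted sum you describe, wpe[i] = sum over j of atten_matrix[B, i, j] * ee[B, j, :], is exactly a batched matrix product, so the whole double loop collapses into one torch.bmm (or @) call. A minimal sketch with made-up sizes, comparing the corrected loop against the loop-free version:

```python
import torch

# Small illustrative sizes; batch, L, and dim are assumptions for this sketch.
batch, L, dim = 2, 4, 3
atten_matrix = torch.rand(batch, L, L)  # attention weights: batch x L x L
ee = torch.rand(batch, L, dim)          # embeddings:        batch x L x dim

# Loop version (the intent of the original code, with w corrected to i):
out_loop = torch.zeros(batch, L, dim)
for B in range(batch):
    for i in range(L):
        # Weighted sum over positions j: sum_j atten[B, i, j] * ee[B, j, :]
        out_loop[B, i] = torch.sum(atten_matrix[B, i].unsqueeze(1) * ee[B], dim=0)

# Loop-free version: a single batched matrix multiply.
out_bmm = torch.bmm(atten_matrix, ee)   # equivalently: atten_matrix @ ee

print(torch.allclose(out_loop, out_bmm))  # prints True
```

This avoids both Python loops and the permute entirely; bmm contracts atten_matrix's last dimension against ee's middle dimension, which is exactly the sum over the L positions.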