Hi there, I ran into a problem in a recent project. In the `forward` method, the `for` loop I use slows the program down so much that it is not acceptable.
As the sample above shows, how could I convert the `for` loop into matrix operations in PyTorch (or with `einsum`) so that the program runs faster? More generally, how should I think about problems like this?
Sorry for my poor English.
Thanks a lot!
Could you post an executable code snippet, including the missing definitions (e.g. `couplings_lookup`), by wrapping it in three backticks ```, please?
```python
import torch

couplings_lookup = torch.arange(1, 901).reshape(6, 6, 5, 5)
fccl = torch.tensor([1, 2, 3, 4, 5, 0])
xx = torch.tensor([3, 2, 2, 2, 1, 0])

coupling = []
for idx, pos in enumerate(fccl):
    # Select entries (fccl[j], pos, xx[j], :) for every earlier step j < idx
    ccoupling_mask = torch.zeros(couplings_lookup.shape)
    ccoupling_mask[fccl[:idx], pos, xx[:idx], :] = 1
    # ccoupling_mask[fccl[:idx], pos, :, :] = 1
    # ccoupling_mask[:, pos, xx[:idx], :] = 1
    # Sum the selected entries over the first three dims -> shape (5,)
    ccoupling = torch.sum(couplings_lookup * ccoupling_mask, dim=(0, 1, 2))
    print(ccoupling)
    coupling.append(ccoupling)
```
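Each mask in the loop selects whole slices `couplings_lookup[fccl[j], pos, xx[j], :]` for the earlier steps `j < idx`, so the masked sum is really just a sum of those slices. A minimal sketch of that idea: gather every slice at once with broadcast advanced indexing, then weight them with a strictly upper-triangular 0/1 matrix through `torch.einsum`. This assumes the `(fccl[j], xx[j])` pairs are distinct, as they are in the sample data: the loop's mask counts a repeated pair only once, whereas this version would add it twice. The helper names `loop_version` and `vectorized_version` are hypothetical.

```python
import torch

couplings_lookup = torch.arange(1, 901).reshape(6, 6, 5, 5)
fccl = torch.tensor([1, 2, 3, 4, 5, 0])
xx = torch.tensor([3, 2, 2, 2, 1, 0])


def loop_version(lookup, fccl, xx):
    # The original loop, stacked into an (n, 5) tensor for comparison
    out = []
    for idx, pos in enumerate(fccl):
        mask = torch.zeros(lookup.shape)
        mask[fccl[:idx], pos, xx[:idx], :] = 1
        out.append(torch.sum(lookup * mask, dim=(0, 1, 2)))
    return torch.stack(out)


def vectorized_version(lookup, fccl, xx):
    # Assumes the (fccl[j], xx[j]) pairs are distinct (see note above)
    n = fccl.numel()
    # gathered[j, i, :] = lookup[fccl[j], fccl[i], xx[j], :] -> shape (n, n, 5)
    gathered = lookup[fccl[:, None], fccl[None, :], xx[:, None], :]
    # weight[j, i] = 1 if j < i, i.e. step j comes before step i
    weight = torch.triu(torch.ones(n, n), diagonal=1)
    # out[i, :] = sum over j of weight[j, i] * gathered[j, i, :]
    return torch.einsum("ji,jik->ik", weight, gathered.float())


loop_out = loop_version(couplings_lookup, fccl, xx)
vec_out = vectorized_version(couplings_lookup, fccl, xx)
print(torch.allclose(loop_out, vec_out))  # True
```

The gathered tensor has shape `(n, n, 5)`, so memory grows with the square of the number of steps rather than with the full size of `couplings_lookup`.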
Thanks for the reply!
Maybe give me some hints about how to use `scatter` to accomplish this? Many thanks.
Sorry for the bother.
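On the scatter question: `Tensor.scatter_` writes along a single dimension, so for a write that indexes four dimensions at once, `Tensor.index_put_` (the operation behind advanced-index assignment such as `mask[...] = 1`) is arguably the more direct tool. A minimal sketch, with a hypothetical helper name `batched_mask_version`: build every per-step mask in one batched tensor with a single `index_put_`, then reduce. Because it writes with assignment rather than accumulation, it keeps the loop's exact semantics even when the same position is selected more than once.

```python
import torch

couplings_lookup = torch.arange(1, 901).reshape(6, 6, 5, 5)
fccl = torch.tensor([1, 2, 3, 4, 5, 0])
xx = torch.tensor([3, 2, 2, 2, 1, 0])


def batched_mask_version(lookup, fccl, xx):
    n = fccl.numel()
    # All (j, i) pairs with j < i: j is an earlier step, i the current one
    j, i = torch.triu_indices(n, n, offset=1)
    # One mask per step i over the first three dims of lookup: (n, 6, 6, 5)
    masks = torch.zeros(n, *lookup.shape[:3])
    # Scatter-style write: masks[i, fccl[j], fccl[i], xx[j]] = 1.
    # Assignment (accumulate=False) sets a repeated position to 1 only once,
    # matching the original loop.
    masks.index_put_((i, fccl[j], fccl[i], xx[j]), torch.tensor(1.0))
    # Broadcast the masks over the last dim and reduce dims 1-3 -> (n, 5)
    return (masks[..., None] * lookup).sum(dim=(1, 2, 3))


print(batched_mask_version(couplings_lookup, fccl, xx))
```

The batched mask has `n` times as many elements as `couplings_lookup`, which is cheap for inputs of this size but worth checking for large ones.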