A basic question about tensor operations in PyTorch

Hi there, I ran into a problem in a recent project. In the “forward” pass, the “for” loop I use slows the program down far more than I can accept.

How can I convert the “for” loop in my sample into a vectorized tensor operation in PyTorch (or an einsum) so that the program runs faster? And more generally, how should I think about problems like this?

Could you post an executable code snippet that includes the missing definitions (e.g. couplings_lookup), wrapped in three backticks ```, please?

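import torch

couplings_lookup = torch.arange(1, 901).reshape(6, 6, 5, 5)

fccl = torch.tensor([1, 2, 3, 4, 5, 0])
xx = torch.tensor([3, 2, 2, 2, 1, 0])
coupling = []
for idx, pos in enumerate(fccl):
    # Mask selecting the entries (fccl[j], pos, xx[j], :) for every j < idx
    ccoupling_mask = torch.zeros(couplings_lookup.shape)
    ccoupling_mask[fccl[:idx], pos, xx[:idx], :] = 1
    # ccoupling_mask[fccl[:idx],pos,:,:] = 1
    # ccoupling_mask[:,pos,xx[:idx],:] = 1
    # Sum the selected entries over the first three dimensions -> shape (5,)
    ccoupling = torch.sum(couplings_lookup * ccoupling_mask, dim=(0, 1, 2))
    coupling.append(ccoupling)  # assumed: collect one result per position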

Could you give me some hints about how to use “scatter” to accomplish this? Many thanks.
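
One possible way to drop the loop, offered as a sketch rather than a definitive answer: the mask only picks out the entries (fccl[j], pos, xx[j], :) for j < idx, so this looks more like a gather than a scatter. You could gather all (i, j) combinations at once with broadcasted advanced indexing and then keep only the pairs with j < i using a strictly lower-triangular mask. The names gathered, lower and coupling_vec are my own, and the sketch assumes the pairs (fccl[j], xx[j]) are distinct (true here, since fccl is a permutation); with repeated pairs the loop's mask would count an entry once while this version would count it several times.

```python
import torch

couplings_lookup = torch.arange(1, 901).reshape(6, 6, 5, 5)
fccl = torch.tensor([1, 2, 3, 4, 5, 0])
xx = torch.tensor([3, 2, 2, 2, 1, 0])
n = fccl.shape[0]

# gathered[i, j, :] = couplings_lookup[fccl[j], fccl[i], xx[j], :]
# The three index tensors broadcast to shape (n, n); the last dim (5) is kept.
gathered = couplings_lookup[fccl[None, :], fccl[:, None], xx[None, :]]

# lower[i, j] = 1 where j < i, which mirrors the prefix slices fccl[:idx], xx[:idx]
lower = torch.tril(torch.ones(n, n), diagonal=-1)

# Sum over j for every i; cast to float so the result matches the loop's dtype
coupling_vec = torch.einsum("ij,ijk->ik", lower, gathered.to(lower.dtype))

print(coupling_vec.shape)  # one row of length 5 per position in fccl
```

Row i of coupling_vec should correspond to the ccoupling computed in iteration i of the loop. The intermediate gathered tensor has n * n * 5 elements, so for a long fccl you may need to process it in chunks.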

