Optimization avoiding loops

Hello everyone! I'm working with PyTorch 2.1. Given that s1 and s2 are two tensors of size (self.nrep, self.npt) with indices ranging from 0 to self.size-1, net is a tensor of size (self.size, self.size, K, 4), and conf is a tensor of size (self.nrep, self.npt, self.size), I would like to optimize this code by avoiding the for loops:

net_conf = torch.zeros(size=(self.nrep, self.npt, K, 4), dtype=torch.complex128, device=f'cuda:{self.dev}')
for r in range(self.nrep):
    for k in range(self.npt):
        net_conf[r,k] = conf[r,k,net[s1[r,k],s2[r,k]]]

I suspect torch.func's vmap or functorch.dim (first-class dimensions) might be useful here.
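
For reference, here is a minimal sketch of what a loop-free version might look like using plain advanced indexing rather than vmap. It assumes net holds integer (torch.long) indices into the last dimension of conf, as the loop implies, and that s1 and s2 are also torch.long. The function name gather_net_conf is hypothetical and used only for illustration.

```python
import torch

def gather_net_conf(conf, net, s1, s2):
    """Loop-free equivalent of net_conf[r, k] = conf[r, k, net[s1[r, k], s2[r, k]]].

    Assumed shapes: conf (nrep, npt, size); net (size, size, K, 4), integer dtype;
    s1 and s2 (nrep, npt), integer dtype.
    """
    nrep, npt = s1.shape
    # Look up one (K, 4) block of indices per (r, k) pair: shape (nrep, npt, K, 4).
    idx = net[s1, s2]
    # Row and column indices, shaped so they broadcast against idx.
    r = torch.arange(nrep, device=conf.device)[:, None, None, None]
    k = torch.arange(npt, device=conf.device)[None, :, None, None]
    # Advanced indexing picks conf[r, k, idx[r, k, a, b]] for every a, b.
    return conf[r, k, idx]
```

The result has shape (nrep, npt, K, 4) and the dtype and device of conf, so no preallocated net_conf is needed. Whether this is faster than a vmap-based approach on a given GPU would have to be measured.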