Any suggestions to speed up the following snippet?

Hi, I am new to this forum. The following code is extremely slow and I am looking for ways to speed it up. The distance tensor has shape (batch_size, D, D). For every triple of points (i, j, k), I want the average of the three pairwise distances d(i, j), d(i, k) and d(j, k). Are there any tricks or native ops that could speed this up?

```python
import torch

@torch.jit.script
def cal_3p_distance(distance: torch.Tensor) -> torch.Tensor:
    # distance: (batch_size, D, D) -> E: (batch_size, D, D, D)
    D = distance.shape[1]
    E = torch.zeros([distance.shape[0], D, D, D], dtype=distance.dtype, device=distance.device)
    for i in range(D):
        for j in range(D):
            for k in range(D):
                # sum of the three pairwise distances among points i, j, k
                E[:, i, j, k] = distance[:, i, j] + distance[:, i, k] + distance[:, j, k]
    E = E / 3
    return E
```
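For reference, the loop computes the following tensor, where $d$ is `distance`, $b$ is the batch index and $i, j, k \in \{0, \dots, D-1\}$:

$$E_{b,i,j,k} = \frac{1}{3}\left(d_{b,i,j} + d_{b,i,k} + d_{b,j,k}\right)$$

Each term is independent of exactly one index: $d_{b,i,j}$ does not depend on $k$, $d_{b,i,k}$ does not depend on $j$, and $d_{b,j,k}$ does not depend on $i$. That is what makes a loop-free formulation possible.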

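You want to use broadcasting. This should do the trick:

```python
E = (distance[:, :, :, None] + distance[:, :, None, :] + distance[:, None, :, :]) / 3
```

Each term gets a singleton axis in the position of the index it does not depend on (k, j and i respectively), so the three views broadcast to shape (batch_size, D, D, D). The parentheses make the division apply to the whole sum rather than only to the last term.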

Best regards

Thomas
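A minimal sketch comparing the original triple loop with the broadcast version on a small random input. The function names `cal_3p_distance_loop` and `cal_3p_distance_broadcast` are hypothetical, chosen only for this comparison, and the TorchScript decorator is dropped from the loop version for simplicity.

```python
import torch

def cal_3p_distance_loop(distance: torch.Tensor) -> torch.Tensor:
    # Reference: the triple loop from the question (hypothetical name, no TorchScript)
    B, D, _ = distance.shape
    E = torch.zeros(B, D, D, D, dtype=distance.dtype, device=distance.device)
    for i in range(D):
        for j in range(D):
            for k in range(D):
                E[:, i, j, k] = distance[:, i, j] + distance[:, i, k] + distance[:, j, k]
    return E / 3

def cal_3p_distance_broadcast(distance: torch.Tensor) -> torch.Tensor:
    # distance[:, i, j] does not depend on k -> view of shape (B, D, D, 1)
    # distance[:, i, k] does not depend on j -> view of shape (B, D, 1, D)
    # distance[:, j, k] does not depend on i -> view of shape (B, 1, D, D)
    # Broadcasting expands all three views to (B, D, D, D) during the sum.
    return (distance[:, :, :, None]
            + distance[:, :, None, :]
            + distance[:, None, :, :]) / 3

torch.manual_seed(0)
distance = torch.rand(2, 5, 5)
E_loop = cal_3p_distance_loop(distance)
E_fast = cal_3p_distance_broadcast(distance)
print(E_fast.shape)                    # torch.Size([2, 5, 5, 5])
print(torch.allclose(E_loop, E_fast))  # True
```

Note that the broadcast version still materializes the full (batch_size, D, D, D) output, so memory grows as D³; what it removes is the D³ Python-level indexing operations, which is where the loop spends its time.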
