Hi, I'm new to this forum. The following code is extremely slow, and I'm looking for ways to speed it up. The distance tensor has shape (batch_size, D, D). For every triple of points (i, j, k), I want the average of the three pairwise distances distance[:, i, j], distance[:, i, k], and distance[:, j, k]. Are there any native ops that could speed this up?

```
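import torch

@torch.jit.script
def cal_3p_distance(distance):
    # distance: (batch_size, D, D); E: (batch_size, D, D, D)
    E = torch.zeros(distance.shape[0], distance.shape[1], distance.shape[1], distance.shape[1]).to(distance.device)
    for i in range(distance.shape[1]):
        for j in range(distance.shape[1]):
            for k in range(distance.shape[1]):
                # Sum of the three pairwise distances for the triple (i, j, k)
                E[:, i, j, k] = (distance[:, i, j] + distance[:, i, k] + distance[:, j, k])
    # Divide once at the end to turn the sums into averages
    E = E / 3
    return E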
```
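
One possible way to avoid the triple loop is broadcasting. The sketch below is only a suggestion, not a benchmarked solution, and the function name `cal_3p_distance_vectorized` is made up for illustration. It assumes `distance` has shape (batch_size, D, D) as described above, and inserts a size-1 axis in a different position for each of the three terms so that their sum broadcasts to (batch_size, D, D, D):

```python
import torch

def cal_3p_distance_vectorized(distance: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper name; assumes distance has shape (batch_size, D, D).
    d_ij = distance.unsqueeze(3)  # (B, D, D, 1): distance[:, i, j], broadcast over k
    d_ik = distance.unsqueeze(2)  # (B, D, 1, D): distance[:, i, k], broadcast over j
    d_jk = distance.unsqueeze(1)  # (B, 1, D, D): distance[:, j, k], broadcast over i
    # Broadcasting expands the sum to (B, D, D, D); divide to get the average.
    return (d_ij + d_ik + d_jk) / 3

# Example usage with a small random tensor
distance = torch.rand(2, 5, 5)
E = cal_3p_distance_vectorized(distance)
print(E.shape)  # torch.Size([2, 5, 5, 5])
```

This should give the same values as the loop version while replacing the D^3 indexed assignments with three elementwise tensor operations. The output still has batch_size * D^3 elements, so memory use grows cubically with D either way.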