I am implementing a pairwise distance function, and I can't find a PyTorch equivalent of the following TensorFlow code:

```
import tensorflow as tf

# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)
```
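A minimal PyTorch sketch of the same computation follows. It assumes the input is an `(N, D)` float tensor of embeddings, that `dot_product` is its Gram matrix, and that `square_norm` is that matrix's diagonal; the original snippet doesn't show how those two are built. `Tensor.unsqueeze(dim)` takes the place of `expand_dims`, and broadcasting expands the `(1, N)` and `(N, 1)` terms to `(N, N)`. The function name `pairwise_squared_distances` is hypothetical.

```python
import torch


def pairwise_squared_distances(embeddings: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between all rows of an (N, D) tensor (hypothetical helper)."""
    # Gram matrix: dot_product[i, j] = <a_i, a_j>, shape (N, N)
    dot_product = embeddings @ embeddings.t()
    # ||a_i||^2 is the diagonal of the Gram matrix, shape (N,)
    square_norm = torch.diagonal(dot_product)
    # unsqueeze(0) -> (1, N), unsqueeze(1) -> (N, 1); broadcasting yields (N, N)
    distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
    # Floating-point error can leave tiny negative values; clamp them to zero
    return distances.clamp(min=0.0)


x = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
print(pairwise_squared_distances(x))  # off-diagonal entries are 25, diagonal is 0
```

If only plain (non-squared) Euclidean distances are needed, `torch.cdist(x, x)` computes them directly; squaring its result should match the sketch above up to rounding.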