Forecast energy score: compute pairwise distances between tensor "columns"


I am trying to implement an energy score loss for a forecast ensemble.

I’m having trouble vectorizing the second term, which is essentially a sum of pairwise distances between different ensemble members (i != j) of the forecast tensor X (y_pred in the code below). X.shape = (batch_size, n_vars, lat, lon, n_members), and i and j index the n_members dimension (X.shape[-1], called m in the code).
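
For reference, the standard ensemble estimate of the energy score for one sample can be written as below. The second term is the unbiased form that the loop computes. Here x_i denotes member i flattened over (n_vars, lat, lon) and y the verifying observation; this notation is introduced for illustration. With batch_size > 1, the loss would be the mean of this quantity over the samples.

$$
\widehat{\mathrm{ES}}(X, y) = \frac{1}{m}\sum_{i=1}^{m} \lVert x_i - y \rVert_2^{\beta}
\;-\; \frac{1}{2m(m-1)} \sum_{i=1}^{m} \sum_{\substack{j=1 \\ j \neq i}}^{m} \lVert x_i - x_j \rVert_2^{\beta},
\qquad \beta \in (0, 2)
$$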

Here’s the “loopy” version I coded for a single sample:

    import torch

    # works only with batch_size == 1; y_pred has shape (1, n_vars, lat, lon, m)
    m = y_pred.shape[-1]
    tmp_b = torch.zeros(1)
    for i in range(m):
        for j in range(m):
            if i != j:
                # more general formula: each 2-norm is raised to the power beta
                tmp_b += torch.pow(torch.linalg.norm((y_pred[..., i] - y_pred[..., j]).flatten(), ord=2), beta)

    # unbiased estimator: the m diagonal (i == j) terms are zero and excluded,
    # so divide by the m * (m - 1) off-diagonal pairs instead of m**2
    tmp_b *= -1.0 / (2 * m * (m - 1))

So I have two questions:

1/ How can I vectorize the code above?
2/ How can I vectorize it for batch_size > 1? (Essentially this means averaging the energy losses over all samples in the batch.)
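
Both questions can be sketched with plain broadcasting: build every member difference explicitly, take the 2-norm over the flattened field, apply the 1 / (2m(m-1)) factor per sample, then average over the batch. This is only a sketch: the intermediate tensor has bs × D × m × m elements (D = n_vars · lat · lon), so it suits small grids or small ensembles only. The function name second_term_broadcast is hypothetical.

```python
import torch


def second_term_broadcast(y_pred: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Batch-averaged second energy-score term via explicit broadcasting.

    y_pred: (batch_size, n_vars, lat, lon, m), ensemble members on the last dim.
    """
    bs, m = y_pred.shape[0], y_pred.shape[-1]
    # Flatten every field dimension: (bs, D, m) with D = n_vars * lat * lon
    x = y_pred.reshape(bs, -1, m)
    # (bs, D, m, 1) - (bs, D, 1, m) -> (bs, D, m, m): all member differences
    diff = x.unsqueeze(-1) - x.unsqueeze(-2)
    # 2-norm over the field dimension D -> (bs, m, m); the diagonal is zero
    dist = torch.linalg.vector_norm(diff, ord=2, dim=1)
    # Per-sample unbiased estimator, then average over the batch
    per_sample = dist.pow(beta).sum(dim=(1, 2)) / (2 * m * (m - 1))
    return -per_sample.mean()
```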


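I was able to vectorize the single-sample (batch_size == 1) code with the help of torch.cdist:

    members = y_pred.reshape(-1, m).T  # (m, n_vars*lat*lon): one row per member
    # cdist -> (m, m); its zero diagonal (i == j) adds nothing to the sum
    tmp_b = -1.0 / (2 * m * (m - 1)) * torch.sum(
        torch.pow(torch.cdist(members, members, p=2), beta)
    )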
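
For the single-sample case, torch.pdist is a close alternative to cdist: it returns only the m(m-1)/2 distances with i < j, so the diagonal is never computed, and each pair is then counted twice to match the i != j sum. torch.pdist only accepts 2-D input, so it does not directly cover batch_size > 1. A sketch, with the hypothetical helper name second_term_pdist:

```python
import torch


def second_term_pdist(y_pred: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Second energy-score term for a single sample (batch_size == 1) via torch.pdist.

    y_pred: (1, n_vars, lat, lon, m), ensemble members on the last dim.
    """
    m = y_pred.shape[-1]
    # One row per member: (m, D); pdist expects a 2-D (N, M) input
    members = y_pred.reshape(-1, m).T.contiguous()
    # Condensed distances for the m * (m - 1) / 2 pairs with i < j (no diagonal)
    d = torch.pdist(members, p=2)
    # The i != j sum counts each unordered pair twice, hence the factor 2
    return -2.0 * d.pow(beta).sum() / (2 * m * (m - 1))
```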

Question 2 still stands. Does anybody know how to vectorize this op for batch sizes > 1? Thanks.

Hmm… I think this does it:

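    bs = y_pred.shape[0]
    # (bs, n_vars*lat*lon, m) -> (bs, m, n_vars*lat*lon): one row per ensemble member
    members = y_pred.reshape(bs, -1, m).permute(0, 2, 1)
    # batched cdist -> (bs, m, m); sum each sample's matrix, then average over the batch
    tmp_b = -1.0 / (2 * m * (m - 1)) * torch.mean(
        torch.sum(torch.pow(torch.cdist(members, members, p=2), beta), dim=(1, 2)),
        dim=0,
    )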
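
Combining the batched second term with the first term gives a complete batch-averaged loss. This is a minimal sketch under stated assumptions: the observation tensor y_true is assumed to have shape (batch_size, n_vars, lat, lon), i.e. the forecast shape without the member dimension, and the names energy_score and y_true are hypothetical.

```python
import torch


def energy_score(y_pred: torch.Tensor, y_true: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Batch-averaged ensemble energy score (lower is better).

    y_pred: (batch_size, n_vars, lat, lon, m) forecast ensemble.
    y_true: (batch_size, n_vars, lat, lon) verifying observation (assumed layout).
    """
    bs, m = y_pred.shape[0], y_pred.shape[-1]
    # One row per member: (bs, m, D) with D = n_vars * lat * lon
    members = y_pred.reshape(bs, -1, m).permute(0, 2, 1)
    # Observation as a single "row": (bs, 1, D)
    obs = y_true.reshape(bs, 1, -1)
    # First term: (1/m) * sum_i ||x_i - y||^beta -> (bs,)
    term_a = torch.cdist(members, obs, p=2).pow(beta).mean(dim=(1, 2))
    # Second term: unbiased pairwise estimator over i != j -> (bs,)
    pair_sum = torch.cdist(members, members, p=2).pow(beta).sum(dim=(1, 2))
    term_b = pair_sum / (2 * m * (m - 1))
    # Average the per-sample scores over the batch
    return (term_a - term_b).mean()
```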