# How to calculate distances between classes

Hello, everyone!

I have the following code:

```python
import torch

feature = torch.randn(10, 20)
target = torch.randint(0, 4, (10,))
total = 0
for i in range(0, 4):
    for j in range(i + 1, 4):
        diff = feature[target == i].mean(0) - feature[target == j].mean(0)
        total += torch.norm(diff, 2)
```

Can you help me optimise it? I feel it can be done without `for` loops, but I don’t understand how to get rid of them.

Hi Podi!

You need to address two issues to avoid explicit loops:

First, you have to rework `feature[target == i]`. This is because the
`shape` of `feature[target == i]` depends on the value of `i`. Because
pytorch does not support ragged tensors (tensors whose slices don’t
have the same shape), you won’t be able to replace this version of the
loop with operations on a single tensor.

Therefore we will mask `feature` rather than indexing into it.
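
To see why the slices are a problem, note that their shapes depend on the class counts (here using the same `target` values as in your run below):

```python
import torch

feature = torch.randn(10, 20)
target = torch.tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])

# The row count of each slice depends on how many samples fall in that
# class, so the four slices can't be stacked into one regular tensor ...
print(feature[target == 0].shape)  # torch.Size([4, 20])
print(feature[target == 1].shape)  # torch.Size([2, 20])

# ... but a boolean membership mask has a fixed shape regardless of counts.
mask = target.unsqueeze(-1) == torch.arange(4)
print(mask.shape)                  # torch.Size([10, 4])
```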

Second, your `i` and `j` loops index over the upper triangle of the matrix
you would get if you wrote `diff` out in matrix form. Pytorch does not
(directly) support triangular matrices.

Therefore we will compute the full square `diff` matrix (which will be
symmetric, so there will be some redundancy, but that’s okay), and use
`triu()` to give us just the elements we want to sum over.
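
As a tiny illustration of the `triu()` step (with a made-up 3 x 3 distance matrix), the upper triangle holds each pair exactly once, so summing it gives the pairwise total:

```python
import torch

# a symmetric "distance" matrix with zeros on the diagonal
d = torch.tensor([[0., 1., 2.],
                  [1., 0., 3.],
                  [2., 3., 0.]])

# triu() zeroes everything below the diagonal, leaving each pair once
print(d.triu())
print(d.triu().sum())  # tensor(6.) -- i.e. 1 + 2 + 3
```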

Thus:

```
>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> feature = torch.randn (10, 20)
>>> target = torch.randint (0, 4, (10,))
>>>
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(32.2318)
>>>
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(32.2318)
```
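
In case the broadcasting in `totalB` is hard to follow, here is the same computation with the intermediate shapes spelled out (a sketch with fresh random data):

```python
import torch

feature = torch.randn(10, 20)
target = torch.randint(0, 4, (10,))

mask = target.unsqueeze(-1) == torch.arange(4)        # (10, 4) class membership
weighted = feature.unsqueeze(1) * mask.unsqueeze(-1)  # (10, 4, 20) non-members zeroed
means = weighted.sum(0) / mask.sum(0).unsqueeze(-1)   # (4, 20) per-class means
pair_diff = means.unsqueeze(0) - means.unsqueeze(1)   # (4, 4, 20) all pairwise diffs
dists = torch.norm(pair_diff, 2, dim=2)               # (4, 4) symmetric matrix
print(dists.shape)  # torch.Size([4, 4])
```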

As an aside, if `(target == i).sum()` is zero for any value of `i` (that is, if
some class is empty), `.mean()` won’t really be defined, and you will get `nan`:

```
>>> target
tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])
>>> target[4] = 1
>>> (target == 2).sum()
tensor(0)
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(nan)
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(nan)
```
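
If you need to be robust to empty classes, one possible guard (my suggestion, not something your original code does) is to drop the rows of `means` whose class count is zero before forming the pairwise distances:

```python
import torch

feature = torch.randn(10, 20)
target = torch.tensor([3, 0, 0, 0, 1, 3, 3, 1, 0, 1])  # class 2 is empty

mask = target.unsqueeze(-1) == torch.arange(4)
counts = mask.sum(0)                                   # per-class sample counts
# clamp avoids a 0 / 0 division for the empty class ...
means = (feature.unsqueeze(1) * mask.unsqueeze(-1)).sum(0) / counts.clamp(min=1).unsqueeze(-1)
# ... and dropping its row keeps it out of the pairwise sum entirely
means = means[counts > 0]
total = torch.norm(means.unsqueeze(0) - means.unsqueeze(1), 2, dim=2).triu().sum()
print(torch.isnan(total).item())  # False
```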

Best.

K. Frank
