How to calculate distances between classes

Hello, everyone

I have the following code:

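import torch

feature = torch.randn(10, 20)          # 10 samples, 20 features each
target = torch.randint(0, 4, (10,))    # class label (0 to 3) for each sample
total = 0
for i in range(0, 4):
  for j in range(i + 1, 4):
    # Euclidean distance between the mean feature vectors of classes i and j
    diff = feature[target == i].mean(0) - feature[target == j].mean(0)
    total += torch.norm(diff, 2)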

Can you help me optimise it? I feel it can be done without for loops, but I don't understand how to get rid of them.

Thanks in advance

Hi Podi!

You need to address two issues to avoid explicit loops:

First, you have to rework feature[target == i], because its shape
depends on the value of i. PyTorch does not support ragged tensors
(tensors whose slices don't have the same shape), so you won't be able
to replace this version of the loop with operations on a single tensor.

Therefore, we will mask feature rather than index into it.
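To make the difference concrete, here is a small sketch using the target values printed later in this thread (tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])). Boolean indexing produces a differently shaped tensor for each class, while the boolean mask always has shape (number of samples, number of classes), so all classes can be handled in one tensor.

```python
import torch

torch.manual_seed(0)
feature = torch.randn(10, 20)
# labels taken from the example output in this thread
target = torch.tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])

# Boolean indexing: the number of rows depends on the class (a "ragged" result)
shapes = [tuple(feature[target == i].shape) for i in range(4)]
print(shapes)  # [(4, 20), (2, 20), (1, 20), (3, 20)]

# Masking: one fixed-shape tensor that covers all classes at once
mask = target.unsqueeze(-1) == torch.arange(4)  # shape (10, 4), dtype bool
print(mask.shape)   # torch.Size([10, 4])
print(mask.sum(0))  # per-class sample counts: tensor([4, 2, 1, 3])
```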

Second, your i and j loops index over the upper triangle of the matrix
you would get if you wrote diff out in matrix form. PyTorch does not
(directly) support triangular matrices.

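Therefore, we will compute the full square matrix of pairwise distances
(which will be symmetric, so there will be some redundancy, but that's
okay), and use triu() to keep just the elements we want to sum over.
triu() also keeps the diagonal, but those entries are the distances from
each class mean to itself, which are zero, so they don't change the sum.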

Thus:

>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> feature = torch.randn (10, 20)
>>> target = torch.randint (0, 4, (10,))
>>>
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(32.2318)
>>>
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(32.2318)
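Two possible refinements, sketched here as alternatives rather than as part of the reply itself. The broadcast product feature.unsqueeze (1) * mask.unsqueeze (-1) builds an intermediate of shape (samples, classes, features); the same per-class sums can come from a matrix product of the transposed mask (cast to the feature dtype) with feature. And torch.pdist returns the p-norm distances between pairs of rows with i < j as a flat vector, so the triu() step isn't needed. The sketch uses fixed labels (the ones printed later in this thread) so that every class is present, and checks the result against the original loop.

```python
import torch

torch.manual_seed(2022)
feature = torch.randn(10, 20)
# fixed labels so every class has at least one sample
target = torch.tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])
num_classes = 4

# Original double loop over class pairs (i < j)
total = 0
for i in range(num_classes):
    for j in range(i + 1, num_classes):
        diff = feature[target == i].mean(0) - feature[target == j].mean(0)
        total += torch.norm(diff, 2)

# Per-class means via a matrix product: (4, 10) @ (10, 20) -> (4, 20)
onehot = (target.unsqueeze(-1) == torch.arange(num_classes)).to(feature.dtype)
means = (onehot.T @ feature) / onehot.sum(0).unsqueeze(-1)

# pdist returns the 4 * 3 / 2 = 6 distances between rows i < j
pair_dists = torch.pdist(means, p=2)
totalC = pair_dists.sum()

print(pair_dists.shape)  # torch.Size([6])
print(total, totalC)
```

With only four classes the memory saving from the matrix product is negligible; it starts to matter with many samples, classes, or features.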

As an aside, if (target == i).sum() is zero for some i (that is, class i
has no samples), the mean for that class isn't defined, and both versions
give nan:

>>> target
tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])
>>> target[4] = 1
>>> (target == 2).sum()
tensor(0)
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(nan)
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(nan)
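If you would rather ignore classes that have no samples than get nan (an assumption about what you want; the original loop doesn't do this), one option is to keep only the rows of the class-mean matrix whose count is nonzero before computing the pairwise distances. The sketch below uses the same labels as above after target[4] = 1, so class 2 is empty.

```python
import torch

torch.manual_seed(2022)
feature = torch.randn(10, 20)
# class 2 has no samples, as in the example above
target = torch.tensor([3, 0, 0, 0, 1, 3, 3, 1, 0, 1])
num_classes = 4

mask = target.unsqueeze(-1) == torch.arange(num_classes)
counts = mask.sum(0)  # tensor([4, 3, 0, 3])
# row 2 of means is nan, because it is computed as 0 / 0
means = (feature.unsqueeze(1) * mask.unsqueeze(-1)).sum(0) / counts.unsqueeze(-1)

# Drop the rows for empty classes before taking pairwise distances
present = counts > 0
means_present = means[present]
total = torch.norm(means_present.unsqueeze(0) - means_present.unsqueeze(1), 2, dim=2).triu().sum()
print(torch.isnan(means).any().item(), torch.isnan(total).item())  # True False
```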

Best.

K. Frank
