# How to calculate distances between classes

Hello, everyone!

I have the following code:

```python
import torch

feature = torch.randn(10, 20)
target = torch.randint(0, 4, (10,))
total = 0
for i in range(0, 4):
    for j in range(i + 1, 4):
        diff = feature[target == i].mean(0) - feature[target == j].mean(0)
        total += torch.norm(diff, 2)
```

Can you help me optimise it? I feel it can be done without `for` loops, but I don’t understand how to get rid of them.

Hi Podi!

You need to address two issues to avoid explicit loops:

First, you have to rework `feature[target == i]`. This is because the
`shape` of `feature[target == i]` depends on the value of `i`. Because
pytorch does not support ragged tensors (tensors whose slices don’t
have the same shape), you won’t be able to replace this version of the
loop with operations on a single tensor.

Therefore we will mask `feature` rather than indexing into it.
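
To see why the slices are a problem, note that their shapes depend on the class counts (here using the same `target` values as in your run below):

```python
import torch

feature = torch.randn(10, 20)
target = torch.tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])

# The row count of each slice depends on how many samples fall in that
# class, so the four slices can't be stacked into one regular tensor ...
print(feature[target == 0].shape)  # torch.Size([4, 20])
print(feature[target == 1].shape)  # torch.Size([2, 20])

# ... but a boolean membership mask has a fixed shape regardless of counts.
mask = target.unsqueeze(-1) == torch.arange(4)
print(mask.shape)                  # torch.Size([10, 4])
```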

Second, your `i` and `j` loops index over the upper triangle of the matrix
you would get if you wrote `diff` out in matrix form. Pytorch does not
(directly) support triangular matrices.

Therefore we will compute the full square `diff` matrix (which will be
symmetric, so there will be some redundancy, but that’s okay), and use
`triu()` to give us just the elements we want to sum over.
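
As a tiny illustration of the `triu()` step (with a made-up 3 x 3 distance matrix), the upper triangle holds each pair exactly once, so summing it gives the pairwise total:

```python
import torch

# a symmetric "distance" matrix with zeros on the diagonal
d = torch.tensor([[0., 1., 2.],
                  [1., 0., 3.],
                  [2., 3., 0.]])

# triu() zeroes everything below the diagonal, leaving each pair once
print(d.triu())
print(d.triu().sum())  # tensor(6.) -- i.e. 1 + 2 + 3
```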

Thus:

```
>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> feature = torch.randn (10, 20)
>>> target = torch.randint (0, 4, (10,))
>>>
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(32.2318)
>>>
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(32.2318)
```
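
In case the broadcasting in `totalB` is hard to follow, here is the same computation with the intermediate shapes spelled out (a sketch with fresh random data):

```python
import torch

feature = torch.randn(10, 20)
target = torch.randint(0, 4, (10,))

mask = target.unsqueeze(-1) == torch.arange(4)        # (10, 4) class membership
weighted = feature.unsqueeze(1) * mask.unsqueeze(-1)  # (10, 4, 20) non-members zeroed
means = weighted.sum(0) / mask.sum(0).unsqueeze(-1)   # (4, 20) per-class means
pair_diff = means.unsqueeze(0) - means.unsqueeze(1)   # (4, 4, 20) all pairwise diffs
dists = torch.norm(pair_diff, 2, dim=2)               # (4, 4) symmetric matrix
print(dists.shape)  # torch.Size([4, 4])
```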

As an aside, if `(target == i).sum()` is zero for any value of `i` (that is, if
some class is empty), `.mean()` won’t really be defined, and you will get `nan`:

```
>>> target
tensor([3, 0, 0, 0, 2, 3, 3, 1, 0, 1])
>>> target[4] = 1
>>> (target == 2).sum()
tensor(0)
>>> total = 0
>>> for i in range(0, 4):
...   for j in range(i + 1, 4):
...     diff = feature[target == i].mean(0) - feature[target == j].mean(0)
...     total += torch.norm(diff, 2)
...
>>> total
tensor(nan)
>>> mask = target.unsqueeze (-1) == torch.arange (4)
>>> means = ((feature.unsqueeze (1) * mask.unsqueeze (-1)).sum (0) / mask.sum (0).unsqueeze (-1))
>>> totalB = torch.norm (means.unsqueeze (0) - means.unsqueeze (1), 2, dim = 2).triu().sum()
>>> totalB
tensor(nan)
```
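
If you need to be robust to empty classes, one possible guard (my suggestion, not something your original code does) is to drop the rows of `means` whose class count is zero before forming the pairwise distances:

```python
import torch

feature = torch.randn(10, 20)
target = torch.tensor([3, 0, 0, 0, 1, 3, 3, 1, 0, 1])  # class 2 is empty

mask = target.unsqueeze(-1) == torch.arange(4)
counts = mask.sum(0)                                   # per-class sample counts
# clamp avoids a 0 / 0 division for the empty class ...
means = (feature.unsqueeze(1) * mask.unsqueeze(-1)).sum(0) / counts.clamp(min=1).unsqueeze(-1)
# ... and dropping its row keeps it out of the pairwise sum entirely
means = means[counts > 0]
total = torch.norm(means.unsqueeze(0) - means.unsqueeze(1), 2, dim=2).triu().sum()
print(torch.isnan(total).item())  # False
```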

Best.

K. Frank
