Find winning unit between 2 torch tensors of different shapes

I am trying to implement a Self-Organizing Map (SOM) where, for a given input sample, the best matching unit (winning unit) is chosen based on (say) the L2-norm distance between each SOM unit's weight vector and the input. To implement this, I have:

import numpy as np
import torch

# Input batch: batch size = 512, input dim = 84
z = torch.randn(512, 84)

# SOM shape: (height, width, input dim)
som = torch.randn(40, 40, 84)

# L2 distance from every SOM unit to a single sample (z[0]) out of the 512
dist_l2 = np.linalg.norm(som.numpy() - z[0].numpy(), ord=2, axis=2)

# dist_l2.shape
# (40, 40)

# Get the (row, column) index of the minimum of the 2-D array
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)

print(f"BMU for z[0]; row = {row}, col  = {col}")
# BMU for z[0]; row = 3, col  = 9

So for the first input sample of 'z', the winning unit in the SOM has the index (3, 9). I could put this in a for loop over all 512 input samples, but that is very inefficient.
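For reference, the per-sample loop described above might look like the following sketch. The function name `bmu_loop` is hypothetical; the body just repeats the single-sample computation for each row of `z`, which also makes it a handy ground truth when checking a vectorized version.

```python
import numpy as np
import torch

def bmu_loop(som: torch.Tensor, z: torch.Tensor) -> np.ndarray:
    """Return a (batch, 2) array of (row, col) BMU indices, one sample at a time."""
    som_np = som.numpy()
    bmus = np.empty((z.shape[0], 2), dtype=np.int64)
    for i in range(z.shape[0]):
        # L2 distance from sample i to every SOM unit: shape (height, width)
        dist = np.linalg.norm(som_np - z[i].numpy(), ord=2, axis=2)
        # (row, col) of the closest unit
        bmus[i] = np.unravel_index(dist.argmin(), dist.shape)
    return bmus

z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)
bmus = bmu_loop(som, z)
print(bmus.shape)  # (512, 2)
```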

Is there an efficient, vectorized way in PyTorch to compute this for the entire batch?

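This should work. Broadcasting the SOM, viewed as shape (40, 40, 1, 84), against the batch, viewed as shape (1, 1, 512, 84), yields all 512 distance maps at once in a tensor of shape (40, 40, 512):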

import numpy as np
import torch

# Input batch: batch size = 512, input dim = 84
z = torch.randn(512, 84)

# SOM shape: (height, width, input dim)
som = torch.randn(40, 40, 84)

# L2 distances for a single sample (z[0]), as in the question: shape (40, 40)
dist_l2 = np.linalg.norm(som.numpy() - z[0].numpy(), ord=2, axis=2)

# L2 distances for all samples at once via broadcasting: shape (40, 40, 512)
out = np.linalg.norm(som.numpy()[:, :, None, :] - z[None, None, :, :].numpy(), ord=2, axis=3)

# Sanity check: the slice for sample 0 matches the single-sample result
print((out[:, :, 0] == dist_l2).all())
# True
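To go from a distance tensor like `out` (shape (40, 40, 512)) to the winning unit of every sample, one option is to flatten the two grid axes, take `argmin` over them, and unravel the flat indices back to (row, column), mirroring the question's single-sample `np.unravel_index` call. `np.unravel_index` accepts an array of flat indices and returns one index array per dimension. A minimal NumPy sketch with its own random data:

```python
import numpy as np
import torch

z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)

# Distances from every SOM unit to every sample: shape (40, 40, 512)
out = np.linalg.norm(som.numpy()[:, :, None, :] - z.numpy()[None, None, :, :], ord=2, axis=3)

height, width, batch = out.shape
# Flatten the grid axes and pick the closest unit per sample: shape (512,)
flat_idx = out.reshape(height * width, batch).argmin(axis=0)
# Convert flat indices back to (row, col) grid coordinates
rows, cols = np.unravel_index(flat_idx, (height, width))

print(rows[0], cols[0])  # BMU of z[0]
```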

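# Same computation in PyTorch, without the NumPy round trip
out = torch.linalg.norm(som[:, :, None, :] - z[None, None, :, :], ord=2, dim=3)
# Compare with the NumPy result; the remaining difference is float32 rounding error
print((out[:, :, 0] - torch.from_numpy(dist_l2)).abs().max())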
# tensor(1.9073e-06)
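The broadcast above materializes a (40, 40, 512, 84) difference tensor (about 69 million floats, roughly 275 MB in float32) before reducing it. If memory is a concern, one alternative, swapped in here rather than taken from the answer above, is `torch.cdist`, which computes pairwise distances between two sets of row vectors. The sketch below assumes the SOM grid can be flattened to (height * width, input_dim); the helper name `find_bmus` is hypothetical. By default `torch.cdist` may use a matrix-multiplication formulation for Euclidean distance, so its values can differ from the broadcast version in the last few digits.

```python
import torch

def find_bmus(som: torch.Tensor, z: torch.Tensor):
    """Return (rows, cols) of the best matching unit for each sample in z.

    som: (height, width, input_dim); z: (batch, input_dim)
    """
    height, width, input_dim = som.shape
    # Pairwise L2 distances between samples and flattened SOM units: (batch, height * width)
    dist = torch.cdist(z, som.reshape(height * width, input_dim), p=2.0)
    # Index of the closest unit for each sample: (batch,)
    flat_idx = dist.argmin(dim=1)
    # Convert flat indices back to (row, col) on the SOM grid
    rows = torch.div(flat_idx, width, rounding_mode="floor")
    cols = flat_idx % width
    return rows, cols

z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)
rows, cols = find_bmus(som, z)
print(rows.shape, cols.shape)  # torch.Size([512]) torch.Size([512])
```

The resulting (512, 1600) distance matrix is much smaller than the 4-D broadcast intermediate, and the same `rows`/`cols` can then index the SOM directly, e.g. `som[rows, cols]` gives the (512, 84) weight vectors of the winning units.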