I am trying to implement a Self-Organizing Map (SOM) in which, for a given input sample, the best matching unit (BMU, or winning unit) is the SOM unit whose weight vector is closest to that input, measured by (say) the L2-norm distance. To implement this, I have:
import numpy as np
import torch

# Input batch: batch-size = 512, input-dim = 84
z = torch.randn(512, 84)
# SOM shape: (height, width, input-dim)
som = torch.randn(40, 40, 84)
# Compute the L2 distance between every SOM unit and a single sample (z[0]) out of the 512 samples
dist_l2 = np.linalg.norm(som.numpy() - z[0].numpy(), ord=2, axis=2)
# dist_l2.shape
# (40, 40)
# Get the (row, column) index of the minimum of a 2D np array
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)
print(f"BMU for z[0]; row = {row}, col = {col}")
# BMU for z[0]; row = 3, col = 9  (varies between runs, since z and som are random)
So for the first input sample of "z", the winning unit in the SOM has index (3, 9). I could put this in a for loop over all 512 input samples, but that is very inefficient.
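For reference, the per-sample loop described above might look like the following minimal sketch. It is written in PyTorch rather than NumPy so no `.numpy()` round trip is needed. The function name `bmu_loop` is hypothetical, and it assumes `som` has shape `(height, width, input_dim)` and `z` has shape `(batch, input_dim)`.

```python
import torch

def bmu_loop(som: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Return a (batch, 2) long tensor of (row, col) BMU indices, one sample at a time."""
    _, width, _ = som.shape
    out = torch.empty(z.shape[0], 2, dtype=torch.long)
    for i in range(z.shape[0]):
        # L2 distance from sample i to every SOM unit -> shape (height, width)
        dist = torch.linalg.norm(som - z[i], dim=2)
        # argmin() on a 2D tensor returns an index into the flattened grid;
        # divmod by the grid width recovers (row, col)
        row, col = divmod(int(dist.argmin()), width)
        out[i, 0] = row
        out[i, 1] = col
    return out

torch.manual_seed(0)
z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)
bmus = bmu_loop(som, z)
print(bmus.shape)  # torch.Size([512, 2])
```

This runs 512 separate `(40, 40, 84)` distance computations, which is the overhead a batched version would try to remove.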
Is there an efficient, vectorized way in PyTorch to compute this for the entire batch?
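One possible vectorized approach, sketched under the same shape assumptions: flatten the SOM grid to `(height * width, input_dim)`, compute all sample-to-unit L2 distances at once with `torch.cdist`, take `argmin` over the units, and convert the flat indices back to `(row, col)` with floor division and modulo. The function name `bmu_batch` is hypothetical.

```python
import torch

def bmu_batch(som: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Return a (batch, 2) tensor of (row, col) BMU indices for every sample at once."""
    height, width, dim = som.shape
    units = som.reshape(height * width, dim)  # (H*W, D): one row per SOM unit
    # Pairwise L2 distances between every sample and every unit -> (batch, H*W)
    dist = torch.cdist(z, units, p=2)
    flat_idx = dist.argmin(dim=1)  # (batch,): index of the closest unit per sample
    rows = torch.div(flat_idx, width, rounding_mode="floor")
    cols = flat_idx % width
    return torch.stack((rows, cols), dim=1)

torch.manual_seed(0)
z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)
bmus = bmu_batch(som, z)
print(bmus.shape)  # torch.Size([512, 2])
print(bmus[0])     # (row, col) of the BMU for z[0]
```

The `(batch, H*W)` distance matrix here is only 512 x 1600 floats. Broadcasting directly, e.g. `torch.linalg.norm(som[None] - z[:, None, None], dim=-1)`, also gives a `(512, 40, 40)` result but materializes a `(512, 40, 40, 84)` intermediate (about 275 MB in float32). Note that with more than 25 input rows, `torch.cdist` by default switches to a matrix-multiplication formulation, so near-ties may resolve differently from the loop; passing `compute_mode="donot_use_mm_for_euclid_dist"` avoids that at some speed cost.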