I have implemented a Self-Organizing Map (SOM) which wraps around both and x and y axis thereby making it toroidal in PyTorch. The code for it is:
class DESOM_linearlayer(nn.Module):
def __init__(
self, latent_dim:int = 50,
map_height:int = 10, map_width:int = 10,
p_norm:int = 2, dist_metric:str = 'l2_dist'
):
super(DESOM_linearlayer, self).__init__()
self.latent_dim = latent_dim
self.map_height = map_height
self.map_width = map_width
self.som_nodes = self.map_height * self.map_width
self.p_norm = p_norm
self.dist_metric = dist_metric
# Create 2D tensor containing 2D coords of SOM's indices
locs = np.array(list(np.array([i, j]) for i in range(self.map_height) for j in range(self.map_width)))
self.locations = torch.from_numpy(locs).to(torch.float32).cuda()
del locs
# SOM trainable weights-
self.som_wts = nn.Parameter(data = torch.empty(self.map_height * self.map_width, self.latent_dim), requires_grad = True)
# Initialize weights as a Gaussian initialization with mean = 0 and std-dev = 1 / sqrt(d)-
self.som_wts.data.normal_(mean = 0.0, std = 1 / np.sqrt(self.latent_dim))
def compute_wrapped_distances_batchmodel(self, coords):
'''
Compute the distances from a tensor of batch of 2D coords BMUs
The distance is computed using a wrapped torus shape,
considering wrapping around both axes.
Parameters:
1. (H, W): height & width of 2D SOM.
2. coords (torch.Tensor): A tensor of shape (batch_size, 2)
representing the (x, y) coordinates.
Returns:
torch.Tensor: A tensor of torus distances b/w BMU & all SOMs
with shape (batch_size, H, W).
'''
# batch_size = bmu_2d_indices.shape[0]
batch_size = coords.shape[0]
# Create a grid of coordinates-
# x_grid, y_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing = 'ij')
x_grid, y_grid = torch.meshgrid(torch.arange(self.map_height), torch.arange(self.map_width), indexing = 'ij')
# x_grid.shape, y_grid.shape
# (torch.Size([20, 20]), torch.Size([20, 20]))
# Expand grid to match batch size-
x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
# x_grid.shape, y_grid.shape
# (torch.Size([1024, 20, 20]), torch.Size([1024, 20, 20]))
# Extract x and y coordinates from 'coords'-
# x_coords = bmu_2d_indices[:, 0].view(-1, 1, 1)
x_coords = coords[:, 0].view(-1, 1, 1)
# y_coords = bmu_2d_indices[:, 1].view(-1, 1, 1)
y_coords = coords[:, 1].view(-1, 1, 1)
# x_coords.shape, y_coords.shape
# (torch.Size([1024, 1, 1]), torch.Size([1024, 1, 1]))
# Move to GPU device-
# x_grid = x_grid.to(device)
x_grid = x_grid.cuda()
# y_grid = y_grid.to(device)
y_grid = y_grid.cuda()
# Compute differences-
dx = x_grid - x_coords
dy = y_grid - y_coords
# dx.shape, dy.shape
# (torch.Size([1024, 20, 20]), torch.Size([1024, 20, 20]))
# Wrap the differences-
# dx_wrapped = torch.where(dx > H//2, dx - H, torch.where(dx < -H//2, dx + H, dx))
dx_wrapped = torch.where(dx > self.map_height // 2, dx - self.map_height, torch.where(dx < -self.map_height // 2, dx + self.map_height, dx))
# dy_wrapped = torch.where(dy > W//2, dy - W, torch.where(dy < -W//2, dy + W, dy))
dy_wrapped = torch.where(dy > self.map_width // 2, dy - self.map_width, torch.where(dy < -self.map_width // 2, dy + self.map_width, dy))
# dx_wrapped.shape, dy_wrapped.shape
# (torch.Size([1024, 20, 20]), torch.Size([1024, 20, 20]))
# Compute Euclidean distance-
distances = torch.sqrt(torch.square(dx_wrapped.float()) + torch.square(dy_wrapped.float()))
# distances.shape
# torch.Size([1024, 20, 20])
return distances
def forward(self, z):
# L2-normalize both 'z' and SOM weights:
if self.dist_metric == 'l2_dist':
# L2-normalize 'z' (convert it into a unit vector)-
z = F.normalize(input = z, p = self.p_norm, dim = 1)
# Pairwise squared L2 distance of each input to all SOM units (L2-norm distance)-
pairwise_squaredl2dist = torch.square(
torch.cdist(
x1 = z,
# Also convert all SOM wts to a unit vector-
x2 = F.normalize(input = self.som_wts, p = self.p_norm, dim = 1),
p = self.p_norm
)
)
# Get BMU indices (1d array)-
bmu_indices = torch.argmin(pairwise_squaredl2dist, dim = 1)
# Get 2D BMU indices-
bmu_2d_indices = self.locations[bmu_indices]
# Convert from 2D indices to wrapped/toroidal coordinates-
l2_bmu_dist = self.compute_wrapped_distances_batchmodel(bmu_2d_indices)
return l2_bmu_dist, pairwise_squaredl2dist
Here l2_bmu_dist
is the pairwise distance of a best matching unit/winner for a given input āzā to all other SOM units, and pairwise_squaredl2dist
is the pair-wise distance of a given input āzā to all SOM units.
I am looking for suggestions related to:
- the correct wrapping to convert it from a usual 2D grid coordinates to a toroidal implementation which is handled by
compute_wrapped_distances_batchmodel()
- further improvements