Self-Organizing Map - toroidal wrapping; Sanity check

I have implemented a Self-Organizing Map (SOM) in PyTorch which wraps around both the x and y axes, thereby making the map toroidal. The code for it is:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DESOM_linearlayer(nn.Module):
    def __init__(
        self, latent_dim:int = 50,
        map_height:int = 10, map_width:int = 10,
        p_norm:int = 2, dist_metric:str = 'l2_dist'
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.map_height = map_height
        self.map_width = map_width
        self.som_nodes = self.map_height * self.map_width
        self.p_norm = p_norm
        self.dist_metric = dist_metric

        # 2D tensor containing the 2D coords of all SOM units; registered as a
        # buffer so that it follows the module across devices-
        locs = np.array([[i, j] for i in range(self.map_height) for j in range(self.map_width)])
        self.register_buffer('locations', torch.from_numpy(locs).to(torch.float32))

        # SOM trainable weights-
        self.som_wts = nn.Parameter(data = torch.empty(self.map_height * self.map_width, self.latent_dim), requires_grad = True)

        # Initialize weights from a Gaussian with mean = 0 and std-dev = 1 / sqrt(latent_dim)-
        self.som_wts.data.normal_(mean = 0.0, std = 1 / np.sqrt(self.latent_dim))


    def compute_wrapped_distances_batchmodel(self, coords):
        '''
        Compute toroidal (wrapped) distances from a batch of 2D BMU
        coordinates to every unit of the 2D SOM, wrapping around both
        axes.

        Parameters:
        coords (torch.Tensor): A tensor of shape (batch_size, 2)
        containing the (row, col) coordinates of each BMU.

        Returns:
        torch.Tensor: A tensor of torus distances between each BMU and
        all SOM units, with shape (batch_size, H, W), where H and W are
        the map's height and width.
        '''
        batch_size = coords.shape[0]

        # Create a grid of row/col coordinates on the same device as 'coords'-
        x_grid, y_grid = torch.meshgrid(
            torch.arange(self.map_height, device = coords.device),
            torch.arange(self.map_width, device = coords.device),
            indexing = 'ij'
        )

        # Expand grids to match batch size; shape: (batch_size, H, W)-
        x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
        y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)

        # Extract x and y coordinates from 'coords'; shape: (batch_size, 1, 1)-
        x_coords = coords[:, 0].view(-1, 1, 1)
        y_coords = coords[:, 1].view(-1, 1, 1)

        # Compute differences; shape: (batch_size, H, W)-
        dx = x_grid - x_coords
        dy = y_grid - y_coords

        # Wrap the differences around the torus (minimum-image convention)-
        dx_wrapped = torch.where(dx > self.map_height // 2, dx - self.map_height, torch.where(dx < -self.map_height // 2, dx + self.map_height, dx))
        dy_wrapped = torch.where(dy > self.map_width // 2, dy - self.map_width, torch.where(dy < -self.map_width // 2, dy + self.map_width, dy))

        # Euclidean distance on the torus; shape: (batch_size, H, W)-
        distances = torch.sqrt(torch.square(dx_wrapped.float()) + torch.square(dy_wrapped.float()))

        return distances


    def forward(self, z):
        if self.dist_metric == 'l2_dist':
            # L2-normalize 'z' (convert it into a unit vector)-
            z = F.normalize(input = z, p = self.p_norm, dim = 1)

            # Pairwise squared L2 distance of each input to all SOM units (L2-norm distance)-
            pairwise_squaredl2dist = torch.square(
                torch.cdist(
                    x1 = z,
                    # Also convert all SOM wts to a unit vector-
                    x2 = F.normalize(input = self.som_wts, p = self.p_norm, dim = 1),
                    p = self.p_norm
                )
            )

            # Get BMU indices (1d array)-
            bmu_indices = torch.argmin(pairwise_squaredl2dist, dim = 1)

            # Get 2D BMU indices-
            bmu_2d_indices = self.locations[bmu_indices]

            # Compute toroidal distances from each BMU to all SOM units-
            l2_bmu_dist = self.compute_wrapped_distances_batchmodel(bmu_2d_indices)

            return l2_bmu_dist, pairwise_squaredl2dist

        raise ValueError(f"Unsupported dist_metric: {self.dist_metric}")

Here, l2_bmu_dist holds the toroidal distances from each input's best matching unit (BMU/winner) to all SOM units, and pairwise_squaredl2dist holds the pairwise squared L2 distances from each input 'z' to all SOM units.
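
As a quick smoke test of the output shapes (a minimal sketch; the batch size of 32 and the default 10x10 map are arbitrary choices on my part):

import torch

som = DESOM_linearlayer(latent_dim = 50, map_height = 10, map_width = 10)
z = torch.randn(32, 50)

l2_bmu_dist, pairwise_squaredl2dist = som(z)

print(l2_bmu_dist.shape)             # torch.Size([32, 10, 10])
print(pairwise_squaredl2dist.shape)  # torch.Size([32, 100])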

I am looking for suggestions related to:

  1. whether the wrapping that converts the usual 2D grid coordinates into a toroidal layout, handled by compute_wrapped_distances_batchmodel(), is correct (see the cross-check sketch below)
  2. further improvements
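
To sanity-check point 1 myself, I compare the layer's output against a reference that uses the per-axis minimum-image distance min(|d|, size - |d|), which is how I understand distance on a torus should behave (this is my own cross-check, not part of the model; the 10x10 map and batch of 64 are arbitrary):

import torch

def toroidal_dist_reference(coords, H, W):
    # Per-axis toroidal distance min(|d|, size - |d|), then Euclidean norm-
    rows = torch.arange(H).view(1, H, 1)
    cols = torch.arange(W).view(1, 1, W)
    dr = (coords[:, 0].view(-1, 1, 1) - rows).abs()
    dc = (coords[:, 1].view(-1, 1, 1) - cols).abs()
    dr = torch.minimum(dr, H - dr)
    dc = torch.minimum(dc, W - dc)
    return torch.sqrt(dr ** 2 + dc ** 2)

som = DESOM_linearlayer(map_height = 10, map_width = 10)
coords = torch.randint(low = 0, high = 10, size = (64, 2)).float()

ref = toroidal_dist_reference(coords, H = 10, W = 10)
out = som.compute_wrapped_distances_batchmodel(coords)
print(torch.allclose(out, ref))  # prints True if the wrapping is correct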