Linear layer not training

I am trying to implement a Deep Embedded Self-Organizing Map (DESOM), which combines an Autoencoder with a trainable SOM layer. I implement the SOM using a Linear layer:

class SOM(nn.Module):
    def __init__(
        self, map_height = 10, map_width = 10,
        latent_dim = 50, p_norm = 2
    ):
        super().__init__()
        
        self.map_height = map_height
        self.map_width = map_width
        self.latent_dim = latent_dim
        self.p_norm = p_norm
        self.som_nodes = self.map_height * self.map_width

        # Uniform sampling for SOM (flattened) weights initialization-
        # self.som_wts = torch.distributions.uniform.Uniform(low = - 1 / np.sqrt(latent_space_dim), high = 1 / np.sqrt(latent_space_dim)).sample((m * n, latent_space_dim))

        # create the embedding dictionary-
        # self.embedding = nn.Embedding(self.som_nodes, self.latent_dim)
        # self.embedding.weight.data.uniform_(-np.sqrt(1 / self.latent_dim), np.sqrt(1 / self.latent_dim))

        # Create SOM using linear layer (without bias)-
        self.som_wts = nn.Linear(in_features = self.latent_dim, out_features = self.map_height * self.map_width, bias = False)
        self.som_wts.weight.data.uniform_(-np.sqrt(1 / self.latent_dim), np.sqrt(1 / self.latent_dim))


    def forward(self):
        pass

The entire model combining Autoencoder with SOM is:

class DESOM(nn.Module):
    def __init__(
        self, latent_dim = 50,
        capacity = 16, map_height = 10,
        map_width = 10, p_norm = 2,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.capacity = capacity
        self.map_height = map_height
        self.map_width = map_width
        self.p_norm = p_norm

        # tot_train_iterations = num_epochs * len(train_loader)
        # self.decay_vals = list(scheduler(it = step, tot = tot_train_iterations) for step in range(1, tot_train_iterations + 6))
        # self.decay_vals = torch.tensor(np.asarray(decay_vals))

        self.encoder = Encoder(latent_dim = self.latent_dim, capacity = self.capacity)
        self.decoder = Decoder(latent_dim = self.latent_dim, capacity = self.capacity)
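        # Pass latent_dim explicitly so the SOM prototypes always match the encoder output size-
        self.som = SOM(map_height = self.map_height, map_width = self.map_width, latent_dim = self.latent_dim, p_norm = self.p_norm)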


    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return z, x_recon

# Specify SOM hyper-params-
# m = SOM height-
som_height = 20

# n = SOM width-
som_width = 20

latent_space_dim = 50

# Initialize DESOM model-
model = DESOM(
    latent_dim = latent_space_dim, capacity = 16,
    map_height = som_height, map_width = som_width,
    p_norm = p_norm
)

model.som.som_wts.weight.shape
# torch.Size([400, 50])

# Randomly initialized weights-
model.som.som_wts.weight.min().item(), model.som.som_wts.weight.max().item()
# (-0.14141587913036346, 0.14140239357948303)

Within the train_one_epoch() function, the main code that trains both the Autoencoder and the SOM layer is:

# Get latent code and reconstruction-
z, x_recon = model(x)

optimizer.zero_grad()

# Autoencoder reconstruction loss-
recon_loss = F.mse_loss(input = x_recon, target = x)

# SOM  training code-
l2_dist_z_soms = torch.cdist(x1 = z, x2 = model.som.som_wts.weight, p = p_norm)
mindist, bmu_indices = torch.min(l2_dist_z_soms, -1)
bmu_locations = locations[bmu_indices]
squared_l2_norm_dists = torch.square(torch.cdist(x1 = bmu_locations, x2 = locations, p = p_norm))

# Compute sigma for current iteration/step-
global step
curr_sigma_val = sigma_0 * torch.exp(-step / lmbda_val)
step += 1

# Compute Gaussian topographic neighborhood-
topo_neighb = torch.exp(-squared_l2_norm_dists / ((2 * torch.square(curr_sigma_val)) + 1e-6))

# Compute topographic loss-
topo_loss = topo_neighb * squared_l2_norm_dists

# Sum along all SOM units and mean along batch-
topo_loss = topo_loss.sum(1).mean()

# Compute total loss-
total_loss = recon_loss + (gamma * topo_loss)
# gamma = 0.001

# Compute gradients wrt computed loss-
total_loss.backward()
        
# Perform one step of gradient descent-
optimizer.step()

The entire code can be found here. I omit the other parts for brevity.

But after training is done, when I look at the trained weights of the SOM layer, they are identical to the initial values:

model.som.som_wts.weight.min().item(), model.som.som_wts.weight.max().item()
# (-0.14141587913036346, 0.14140239357948303)

model.som.som_wts.weight.requires_grad
# True

The SOM as a linear layer does not appear to be trained at all! What’s preventing it from being trained?
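One way to narrow this down: requires_grad = True only says the parameter is eligible for gradients. Whether it actually received one shows up in .grad after backward(). Below is a minimal sketch of that check; TinyModel and params_without_grad are hypothetical names used only for illustration and are not part of the DESOM code.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy module (hypothetical): one layer feeds the loss, the other does not."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 3, bias=False)

    def forward(self, x):
        # Only self.used takes part in the forward pass.
        return self.used(x)


def params_without_grad(module):
    """Names of trainable parameters that received no gradient in backward()."""
    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and p.grad is None
    ]


tiny = TinyModel()
loss = tiny(torch.randn(5, 4)).pow(2).mean()
loss.backward()

print(tiny.unused.weight.requires_grad)  # True
print(params_without_grad(tiny))         # ['unused.weight']
```

Running the same check on the DESOM model right after total_loss.backward() would show whether model.som.som_wts.weight is missing from the computation graph.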

I don’t see any usage of the SOM layer besides in:

l2_dist_z_soms = torch.cdist(x1 = z, x2 = model.som.som_wts.weight, p = p_norm)
mindist, bmu_indices = torch.min(l2_dist_z_soms, -1)

which is then using the detached indices in:

bmu_locations = locations[bmu_indices]

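If that’s the only place where the SOM weights are used, no training would be expected: the indices returned by torch.min are integers and are detached from the computation graph, while mindist, which does carry gradients, is never used in the loss.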
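This can be reproduced in isolation. The sketch below uses small random stand-ins for z and for the SOM weights (shapes and values are assumptions, not taken from the original code) and a dummy differentiable term in place of the reconstruction loss.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the encoder output and for model.som.som_wts.weight.
z = torch.randn(4, 5, requires_grad=True)
som_weights = torch.randn(9, 5, requires_grad=True)

# 3 x 3 grid of SOM unit coordinates, flattened to shape (9, 2).
rows, cols = torch.meshgrid(torch.arange(3.0), torch.arange(3.0), indexing="ij")
locations = torch.stack([rows, cols], dim=-1).reshape(-1, 2)

dists = torch.cdist(z, som_weights, p=2)
mindist, bmu_indices = torch.min(dists, -1)

print(mindist.requires_grad)      # True: the values are part of the graph
print(bmu_indices.dtype)          # torch.int64
print(bmu_indices.requires_grad)  # False: the indices are not

# Everything below depends on som_weights only through the integer indices.
bmu_locations = locations[bmu_indices]
grid_sq_dists = torch.cdist(bmu_locations, locations, p=2) ** 2
topo_loss = grid_sq_dists.sum(1).mean()
print(topo_loss.requires_grad)    # False

# Dummy differentiable term standing in for the reconstruction loss.
total_loss = z.pow(2).mean() + 0.001 * topo_loss
total_loss.backward()

print(z.grad is not None)         # True
print(som_weights.grad is None)   # True: no gradient ever reaches the SOM weights
```

backward() raises no error because the other loss term is differentiable, which is why the problem goes unnoticed during training.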

This is just experimental code. I can and will move all of these computations inside the SOM layer once it is training as expected.

What can I do to prevent this detachment?

Moving the code won’t fix the issue, since integer types are not usefully differentiable.
Their gradient would be zero everywhere except at the rounding points, where it would be undefined or Inf.
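A possible way forward is to keep the integer BMU indices only for selecting the neighborhood weights, and to let the loss depend on the SOM weights through a continuous quantity. The sketch below assumes the topographic loss is meant to weight the squared latent-space distances between z and the SOM prototypes by the grid neighborhood, which is how the DESOM loss is usually formulated. That is an assumption about the intent, not a confirmed fix for the linked code. All sizes, the fixed sigma and the random inputs are illustrative.

```python
import torch

torch.manual_seed(0)

batch_size, latent_dim, height, width = 8, 5, 3, 3
sigma = 1.0  # fixed neighborhood radius for this sketch (no decay schedule)

z = torch.randn(batch_size, latent_dim, requires_grad=True)  # stand-in for encoder output
som_weights = torch.nn.Parameter(torch.randn(height * width, latent_dim))

# Grid coordinates of the SOM units, shape (height * width, 2).
rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
locations = torch.stack([rows, cols], dim=-1).reshape(-1, 2).float()

# Squared latent-space distances: differentiable w.r.t. z and som_weights.
latent_sq_dists = torch.cdist(z, som_weights, p=2) ** 2

# Integer BMU indices are used only to look up grid positions.
bmu_indices = latent_sq_dists.argmin(dim=-1)
grid_sq_dists = torch.cdist(locations[bmu_indices], locations, p=2) ** 2

# Gaussian neighborhood on the grid; it carries no gradient and acts as a weight.
topo_neighb = torch.exp(-grid_sq_dists / (2 * sigma ** 2))

# Neighborhood-weighted latent distances: sum over SOM units, mean over batch.
topo_loss = (topo_neighb * latent_sq_dists).sum(1).mean()
topo_loss.backward()

print(som_weights.grad is not None)                 # True
print(som_weights.grad.abs().sum().item() > 0)      # True
```

With this form the gradient reaches the SOM weights through latent_sq_dists, so the optimizer can update them. The neighborhood term stays a constant weighting, so the non-differentiable index lookup is no longer on the only path to the weights.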