# torch.argmin() non-differentiability workaround

I am implementing a topography-constrained neural network layer. The layer can be thought of as a 2D grid map. It takes four arguments: height, width, latent dimensionality, and p-norm (for distance computations). Each unit/neuron has dimensionality equal to `latent_dim`. The code for this class is:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Topography(nn.Module):
    def __init__(
        self, latent_dim: int = 128,
        height: int = 20, width: int = 20,
        p_norm: int = 2
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.height = height
        self.width = width
        self.p_norm = p_norm

        # 2D grid coordinates (i, j) of every unit, shape (height * width, 2).
        locs = np.array([[i, j] for i in range(self.height) for j in range(self.width)])
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("locations", torch.from_numpy(locs).to(torch.float32))

        # Trainable weights: one latent_dim-sized vector per grid unit.
        self.lin_wts = nn.Parameter(torch.empty(self.height * self.width, self.latent_dim))

        # Gaussian initialization with mean = 0 and std-dev = 1 / sqrt(d).
        nn.init.normal_(self.lin_wts, mean=0.0, std=1 / np.sqrt(self.latent_dim))

    def forward(self, z):
        # Normalize 'z' to unit p-norm (a unit L2 vector when p_norm = 2).
        z = F.normalize(z, p=self.p_norm, dim=1)

        # Pairwise squared distance of each input to all SOM units.
        # The unit weights are normalized to unit vectors as well.
        pairwise_squaredl2dist = torch.square(
            torch.cdist(
                x1=z,
                x2=F.normalize(input=self.lin_wts, p=self.p_norm, dim=1),
                p=self.p_norm,
            )
        )

        # For each input z_i, index of the closest unit in 'lin_wts'.
        # These are integer indices and carry no gradient.
        closest_indices = torch.argmin(pairwise_squaredl2dist, dim=1)

        # 2D grid coordinates of the closest units.
        closest_2d_indices = self.locations[closest_indices]

        # Squared grid distance between the closest unit and every other unit.
        l2_dist_squared_topo_neighb = torch.square(
            torch.cdist(x1=closest_2d_indices, x2=self.locations, p=self.p_norm)
        )

        return l2_dist_squared_topo_neighb, pairwise_squaredl2dist
```

For a given input `z`, the layer finds the closest unit and then builds a topographic neighborhood around that unit with a Gaussian (radial basis function) kernel. This is the `topo_neighb` tensor below.
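Written out, the quantities computed by the code in this post look roughly as follows. The symbols are introduced here only for illustration and assume `p_norm = 2`: $\hat{z}_i$ and $\hat{w}_j$ are the normalized input and unit weight vectors, $r_j$ is the 2D grid coordinate of unit $j$, $c(i)$ is the index of the closest unit, $\sigma$ is `curr_sigma`, $\epsilon$ is the `1e-5` constant, $B$ is the batch size, and $H W$ is the number of units.

```latex
c(i) = \arg\min_{j} \lVert \hat{z}_i - \hat{w}_j \rVert_2^2,
\qquad
h_{ij} = \exp\!\left( -\frac{\lVert r_{c(i)} - r_j \rVert_2^2}{2\sigma^2 + \epsilon} \right),
\qquad
\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \sum_{j=1}^{HW} h_{ij}\, \lVert \hat{z}_i - \hat{w}_j \rVert_2^2
```

Only the squared-distance factor in $\mathcal{L}$ depends smoothly on the weights; $h_{ij}$ depends on them only through the discrete choice $c(i)$.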

Since `torch.argmin()` returns integer indices, which (like a one-hot selection) are non-differentiable, I am trying to work around that:

```python
import torch

# Uses the `Topography` class defined above.

# Number of 2D units.
height = 20
width = 20

# Each unit has dimensionality specified as:
latent_dim = 128

# Use L2-norm for distance computations.
p_norm = 2

topo_layer = Topography(latent_dim=latent_dim, height=height, width=width, p_norm=p_norm)

optimizer = torch.optim.SGD(params=topo_layer.parameters(), lr=0.001, momentum=0.9)

batch_size = 1024

# Create a batch of input vectors.
z = torch.rand(batch_size, latent_dim)

# Clear any gradients left over from a previous step.
optimizer.zero_grad()

l2_dist_squared_topo_neighb, pairwise_squaredl2dist = topo_layer(z)

# l2_dist_squared_topo_neighb.size(), pairwise_squaredl2dist.size()
# (torch.Size([1024, 400]), torch.Size([1024, 400]))

curr_sigma = torch.tensor(5.0)

# Compute Gaussian topological neighborhood structure wrt closest unit.
topo_neighb = torch.exp(torch.div(torch.neg(l2_dist_squared_topo_neighb), ((2.0 * torch.square(curr_sigma)) + 1e-5)))

# Compute topographic loss.
loss_topo = (topo_neighb * pairwise_squaredl2dist).sum(dim=1).mean()

loss_topo.backward()

optimizer.step()
```
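To make the non-differentiability concrete, here is a minimal sketch on a made-up toy tensor (the name `dist` and its shape are assumptions for illustration, not part of the layer). It shows that the indices returned by `torch.argmin` are detached from the graph, while values gathered with those indices still receive gradients.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy distances: 3 inputs, 4 units.
dist = torch.rand(3, 4, requires_grad=True)

# argmin returns integer (int64) indices that are not part of the autograd graph.
idx = torch.argmin(dist, dim=1)
print(idx.dtype, idx.requires_grad)  # torch.int64 False

# Gathering the minimum *values* with those indices is still differentiable,
# but only with respect to the selected entries.
min_vals = dist.gather(1, idx.unsqueeze(1)).squeeze(1)
min_vals.sum().backward()

# One entry per row receives gradient 1; every other entry receives 0.
print(dist.grad)
```

So the index selection acts like a constant as far as autograd is concerned: gradients can flow through whatever is indexed, but never through the choice of index.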

Now the loss value decreases from step to step. As a sanity check, I am also logging the L2 norm of `topo_layer.lin_wts` to confirm that the weights are being updated by the gradients.
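A more direct check than logging the weight norm is to inspect which tensors are attached to the autograd graph. The sketch below repeats the same computation inline with small, made-up sizes (`height`, `width`, `latent_dim`, `batch_size`, and the bare `weights` parameter are assumptions so it runs on its own). It suggests that the gradient reaches the weights only through the pairwise distances, while the neighborhood term behaves as a constant weighting.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small sizes so the check runs quickly.
height, width, latent_dim, batch_size = 4, 4, 8, 16
sigma = 5.0

# Grid coordinates of every unit, shape (height * width, 2).
locations = torch.cartesian_prod(
    torch.arange(height), torch.arange(width)
).to(torch.float32)

weights = torch.nn.Parameter(
    torch.randn(height * width, latent_dim) / latent_dim ** 0.5
)
z = F.normalize(torch.rand(batch_size, latent_dim), dim=1)

# Same steps as the layer in the post, written inline.
sq_dist = torch.cdist(z, F.normalize(weights, dim=1)) ** 2
closest = torch.argmin(sq_dist, dim=1)  # the graph is cut here
grid_sq_dist = torch.cdist(locations[closest], locations) ** 2
topo_neighb = torch.exp(-grid_sq_dist / (2.0 * sigma ** 2 + 1e-5))

loss = (topo_neighb * sq_dist).sum(dim=1).mean()
loss.backward()

print(topo_neighb.requires_grad)  # False
print(sq_dist.requires_grad)      # True
print(weights.grad.norm())        # gradient norm reaching the weights
```

This matches the usual self-organizing-map setup, where the neighborhood function is a fixed weighting for the current step and only the distance term is optimized.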

Is this a correct implementation, or am I missing something?

It seems to me that you are essentially reimplementing `Softmin`, with x being the L2 distances you want to take the `min` over and `curr_sigma = 1`. Maybe using `Softmin` to compute the "topological neighborhood" directly would save you a couple of lines of code…
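One possible reading of the `Softmin` suggestion, sketched under assumptions: replace the hard `argmin` with a soft assignment over units and use the expected grid position as a differentiable stand-in for the closest unit's coordinates. The `temperature` value, the small sizes, and the bare `weights` parameter are hypothetical and chosen only to keep the example self-contained. This is not necessarily what the reply had in mind.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small sizes and settings for illustration.
height, width, latent_dim, batch_size = 4, 4, 8, 16
sigma = 5.0
temperature = 0.1  # smaller values push the soft assignment toward a hard argmin

locations = torch.cartesian_prod(
    torch.arange(height), torch.arange(width)
).to(torch.float32)

weights = torch.nn.Parameter(
    torch.randn(height * width, latent_dim) / latent_dim ** 0.5
)
z = F.normalize(torch.rand(batch_size, latent_dim), dim=1)

sq_dist = torch.cdist(z, F.normalize(weights, dim=1)) ** 2

# Soft assignment: each row sums to 1, with the most mass on the closest unit.
soft_assign = F.softmin(sq_dist / temperature, dim=1)

# Differentiable stand-in for locations[argmin]: the expected grid position.
soft_coords = soft_assign @ locations  # shape (batch_size, 2)

grid_sq_dist = torch.cdist(soft_coords, locations) ** 2
topo_neighb = torch.exp(-grid_sq_dist / (2.0 * sigma ** 2 + 1e-5))

loss = (topo_neighb * sq_dist).sum(dim=1).mean()
loss.backward()

print(topo_neighb.requires_grad)  # True
```

With this variant the neighborhood term is part of the graph, so the weights also receive gradients through the choice of neighborhood center. Whether that is desirable depends on the training objective; the hard-`argmin` version keeps the classic SOM behavior.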