I am implementing a topography-constraining neural network layer. This layer can be thought of as akin to a 2D grid map. It takes 4 arguments, viz., height, width, latent dimensionality and p-norm (used for distance computations). Each unit/neuron has dimensionality equal to `latent_dim`. The code for this class is:
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Topography(nn.Module):
    def __init__(
        self, latent_dim: int = 128,
        height: int = 20, width: int = 20,
        p_norm: int = 2
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.height = height
        self.width = width
        self.p_norm = p_norm

        # Create 2D tensor containing the (row, col) grid coordinates of every unit
        locs = np.array([[i, j] for i in range(self.height) for j in range(self.width)])
        # Register as a non-trainable buffer so it follows .to(device) with the module
        self.register_buffer("locations", torch.from_numpy(locs).to(torch.float32))

        # Linear layer's trainable weights: one latent_dim vector per grid unit
        self.lin_wts = nn.Parameter(data = torch.empty(self.height * self.width, self.latent_dim), requires_grad = True)
        # Gaussian initialization with mean = 0 and std-dev = 1 / sqrt(d)
        nn.init.normal_(self.lin_wts, mean = 0.0, std = 1 / np.sqrt(self.latent_dim))

    def forward(self, z):
        # Normalize 'z' to unit p-norm (a unit vector when p_norm = 2)
        z = F.normalize(z, p = self.p_norm, dim = 1)
        # Distance of each input to all SOM units, squared
        # (this is the squared L2 distance only when p_norm = 2)
        pairwise_squaredl2dist = torch.square(
            torch.cdist(
                x1 = z,
                # Also normalize every unit's weight vector
                x2 = F.normalize(input = self.lin_wts, p = self.p_norm, dim = 1),
                p = self.p_norm
            )
        )
        # For each input z_i, index of the closest unit in 'lin_wts' (non-differentiable)
        closest_indices = torch.argmin(pairwise_squaredl2dist, dim = 1)
        # 2D grid coordinates of each closest unit
        closest_2d_indices = self.locations[closest_indices]
        # Squared grid distance between the closest unit and every other unit
        l2_dist_squared_topo_neighb = torch.square(
            torch.cdist(x1 = closest_2d_indices, x2 = self.locations, p = self.p_norm)
        )
        return l2_dist_squared_topo_neighb, pairwise_squaredl2dist
```
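A side note on the `locations` grid: assuming PyTorch's `torch.cartesian_prod`, the same row-major grid of (row, col) coordinates can be built without the NumPy round-trip. The check below uses a hypothetical 4 × 3 grid to confirm that the two constructions agree.

```python
import numpy as np
import torch

height, width = 4, 3  # hypothetical small grid for the comparison

# Construction used in Topography.__init__: [i, j] pairs in row-major order
locs_np = np.array([[i, j] for i in range(height) for j in range(width)])
locs_ref = torch.from_numpy(locs_np).to(torch.float32)

# Torch-only construction: Cartesian product of row indices and column indices
locs_torch = torch.cartesian_prod(
    torch.arange(height, dtype=torch.float32),
    torch.arange(width, dtype=torch.float32),
)

print(locs_torch.shape)                   # torch.Size([12, 2])
print(torch.equal(locs_ref, locs_torch))  # True
```

`torch.cartesian_prod` follows `itertools.product` ordering (the first input varies slowest), which is why it matches the nested `for i ... for j ...` loop.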
For a given input `z`, the layer finds the closest unit and then builds a topographic structure around that unit using a Gaussian radial basis function (RBF) kernel of the squared grid distance; this is the `topo_neighb` tensor below.
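Written out for `p_norm = 2` (my own notation, not from any library: $\hat z_i$ and $\hat w_j$ are the normalized input and unit weight vectors, $r_j \in \mathbb{R}^2$ is the grid coordinate of unit $j$ from `locations`, $H \times W$ is the grid size and $N$ the batch size), the closest unit, the neighbourhood weights and the loss computed in the snippet below are:

$$
c_i = \arg\min_j \lVert \hat z_i - \hat w_j \rVert_2^2, \qquad
h_{ij} = \exp\!\left(-\frac{\lVert r_j - r_{c_i} \rVert_2^2}{2\sigma^2 + 10^{-5}}\right), \qquad
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{HW} h_{ij}\,\lVert \hat z_i - \hat w_j \rVert_2^2
$$

Away from ties, $c_i$ is locally constant, so $h_{ij}$ behaves as a constant weight and

$$
\frac{\partial \mathcal{L}}{\partial \hat w_j} = \frac{2}{N}\sum_{i=1}^{N} h_{ij}\,(\hat w_j - \hat z_i).
$$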
Since `torch.argmin()` returns integer indices (equivalent to a hard one-hot selection), which are non-differentiable, I am trying to create a workaround:
```python
# Number of 2D units
height = 20
width = 20
# Each unit has dimensionality specified as
latent_dim = 128
# Use L2-norm for distance computations
p_norm = 2

topo_layer = Topography(latent_dim = latent_dim, height = height, width = width, p_norm = p_norm)
optimizer = torch.optim.SGD(params = topo_layer.parameters(), lr = 0.001, momentum = 0.9)

batch_size = 1024
# Create a batch of input vectors
z = torch.rand(batch_size, latent_dim)

l2_dist_squared_topo_neighb, pairwise_squaredl2dist = topo_layer(z)
# l2_dist_squared_topo_neighb.size(), pairwise_squaredl2dist.size()
# (torch.Size([1024, 400]), torch.Size([1024, 400]))

curr_sigma = torch.tensor(5.0)
# Compute Gaussian topological neighborhood structure wrt closest unit
topo_neighb = torch.exp(torch.div(torch.neg(l2_dist_squared_topo_neighb), ((2.0 * torch.square(curr_sigma)) + 1e-5)))
# Compute topographic loss
loss_topo = (topo_neighb * pairwise_squaredl2dist).sum(dim = 1).mean()

optimizer.zero_grad()  # clear any previously accumulated gradients
loss_topo.backward()
optimizer.step()
```
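One way to check where the gradient actually goes: the argmin indices, and therefore `topo_neighb`, carry no autograd history, so `topo_neighb` acts as a fixed per-unit weighting and the gradient reaches `lin_wts` only through `pairwise_squaredl2dist`. The sketch below re-implements the forward pass inline on a hypothetical 5 × 5 grid (sizes and seed are arbitrary choices) instead of importing the class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
height, width, latent_dim, batch = 5, 5, 16, 32  # hypothetical small sizes

# Trainable unit weights and fixed (row, col) grid coordinates, as in Topography
lin_wts = torch.nn.Parameter(torch.randn(height * width, latent_dim) / latent_dim ** 0.5)
locations = torch.cartesian_prod(torch.arange(height, dtype=torch.float32),
                                 torch.arange(width, dtype=torch.float32))
z = torch.rand(batch, latent_dim)

# Squared L2 distance between normalized inputs and normalized unit weights
d2 = torch.cdist(F.normalize(z, dim=1), F.normalize(lin_wts, dim=1)).square()

# argmin yields integer indices with no autograd history
closest = torch.argmin(d2, dim=1)
assert closest.dtype == torch.int64 and not closest.requires_grad

# Gaussian neighbourhood around each closest unit (sigma = 5, as in the post)
grid_d2 = torch.cdist(locations[closest], locations).square()
topo_neighb = torch.exp(-grid_d2 / (2.0 * 5.0 ** 2 + 1e-5))
assert not topo_neighb.requires_grad  # constant weighting w.r.t. lin_wts

# Gradients still reach lin_wts, through d2 only
loss = (topo_neighb * d2).sum(dim=1).mean()
loss.backward()
print(lin_wts.grad.norm())
```

So the non-differentiable argmin only decides which neighbourhood weights apply; it does not block training, as far as I can tell.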
With this, the value of the cost function changes and decreases. As a sanity check, I also log the L2-norm of `topo_layer.lin_wts` to confirm that its weights are being updated by the gradients.
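The weight norm shows that the weights change, but the size of each update is a more direct signal. Calling `optimizer.zero_grad()` every step also matters once this runs in a loop, because PyTorch accumulates gradients across `backward()` calls. A minimal loop sketch, again on a hypothetical 5 × 5 grid with a fixed batch so that the losses are comparable; `topo_loss` is a hypothetical helper mirroring the forward pass and loss above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
height, width, latent_dim, batch = 5, 5, 16, 64  # hypothetical small sizes
lin_wts = torch.nn.Parameter(torch.randn(height * width, latent_dim) / latent_dim ** 0.5)
locations = torch.cartesian_prod(torch.arange(height, dtype=torch.float32),
                                 torch.arange(width, dtype=torch.float32))
optimizer = torch.optim.SGD([lin_wts], lr=0.001, momentum=0.9)
z = torch.rand(batch, latent_dim)  # fixed batch so losses are comparable

def topo_loss(z, sigma=5.0):
    # Same computation as Topography.forward plus the loss, with p_norm = 2
    d2 = torch.cdist(F.normalize(z, dim=1), F.normalize(lin_wts, dim=1)).square()
    closest = torch.argmin(d2, dim=1)
    grid_d2 = torch.cdist(locations[closest], locations).square()
    topo_neighb = torch.exp(-grid_d2 / (2.0 * sigma ** 2 + 1e-5))
    return (topo_neighb * d2).sum(dim=1).mean()

losses, weight_deltas = [], []
for step in range(100):
    prev = lin_wts.detach().clone()
    optimizer.zero_grad()  # clear gradients from the previous step
    loss = topo_loss(z)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    # Norm of the actual parameter update made by this step
    weight_deltas.append((lin_wts.detach() - prev).norm().item())

print(losses[0], losses[-1])
```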
Is this a correct implementation, or am I missing something?