K-means Loss Calculation

Can someone give me an idea of how to implement a k-means clustering loss in PyTorch?


Also, I am using PyTorch's nn.MSELoss. Is there a way to add L2 regularization to this term? In short, I want to use an L2-regularized loss.

Could you explain what s_ik is?
Which parameters do you want to apply the L2 regularization to?

@ptrblck
For the k-means loss, this is what I am doing. I just want to verify that it is correct, or whether there is a cleaner way of doing this.

```python
import torch

class KMeansClusteringLoss(torch.nn.Module):

    def __init__(self):
        super(KMeansClusteringLoss, self).__init__()

    def forward(self, encode_output, centroids):
        assert (encode_output.shape[1] == centroids.shape[1]), "Dimensions Mismatch"
        n = encode_output.shape[0]
        d = encode_output.shape[1]
        k = centroids.shape[0]

        # copy every point k times: (n, k, d)
        z = encode_output.reshape(n, 1, d)
        z = z.repeat(1, k, 1)

        # copy every centroid n times: (n, k, d)
        mu = centroids.reshape(1, k, d)
        mu = mu.repeat(n, 1, 1)

        # Euclidean (not squared) distance from each point to each centroid: (n, k)
        dist = (z - mu).norm(2, dim=2).reshape((n, k))
        # distance to the nearest centroid, averaged over points
        loss = dist.min(dim=1)[0].mean()

        return loss
```

s_ik is basically a one-hot indicator: it is 1 if data point i belongs to cluster k and 0 otherwise.
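For reference, the usual k-means objective written with this indicator is shown below. This assumes the equation discussed in this thread has the standard form, normalized by the number of points N to match the `.mean()` in the code; z_i is the encoder output for point i and \mu_k is the k-th centroid.

```latex
L = \frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} s_{ik} \, \lVert z_i - \mu_k \rVert_2^2,
\qquad
s_{ik} = \begin{cases}
  1 & \text{if } k = \arg\min_{j} \lVert z_i - \mu_j \rVert_2^2 \\
  0 & \text{otherwise}
\end{cases}
```

With s_ik fixed to the nearest centroid, the inner sum collapses to \min_k \lVert z_i - \mu_k \rVert_2^2. That is why taking a min over the point-to-centroid distance matrix can replace an explicit s_ik.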

And for L2 regularization, I simply want to implement ridge regression: Loss + \lambda ||w||_2^2,

where \lambda is a hyperparameter and Loss is the output of nn.MSELoss().

I’d probably not use repeat but let broadcasting do its thing. Also, I personally prefer unsqueeze or indexing (z = z[None, :, :]) over reshape for adding the new dimensions.
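A minimal sketch of the broadcasting suggestion, using small made-up sizes. It shows that unsqueeze/None-indexing gives exactly the same differences as the repeat version without first building two full (n, k, d) copies of the inputs. `expand` makes the broadcast explicit: it returns a stride-0 view that shares memory with the original tensor.

```python
import torch

n, k, d = 6, 3, 4
z = torch.randn(n, d)    # encoder outputs, one row per point
mu = torch.randn(k, d)   # centroids, one row per cluster

# repeat: materializes two full (n, k, d) copies before subtracting
diff_repeat = z.reshape(n, 1, d).repeat(1, k, 1) - mu.reshape(1, k, d).repeat(n, 1, 1)

# broadcasting: add size-1 dims, and the subtraction broadcasts them to (n, k, d)
diff_bcast = z.unsqueeze(1) - mu[None, :, :]

print(torch.equal(diff_repeat, diff_bcast))  # True

# expand() is what broadcasting does implicitly: a view with stride 0
# along the new dimension, so no data is copied
mu_view = mu[None].expand(n, k, d)
print(mu_view.stride())                      # (0, 4, 1)
print(mu_view.data_ptr() == mu.data_ptr())   # True
```

The result of the subtraction is still an (n, k, d) tensor. The saving is that the two repeated input copies are never allocated.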

Best regards

Thomas

@tom Could you explain a bit more how exactly this should be done? And is there an advantage to using broadcasting over repeat?

If you put triple backticks (```python) before your code and plain triple backticks (```) after it, it will be well formatted.

In Jupyter:

```python
import torch

class KMeansClusteringLoss(torch.nn.Module):
    def __init__(self):
        super(KMeansClusteringLoss, self).__init__()

    def forward(self, encode_output, centroids):
        assert (encode_output.shape[1] == centroids.shape[1]), "Dimensions Mismatch"
        n = encode_output.shape[0]
        d = encode_output.shape[1]
        k = centroids.shape[0]

        z = encode_output.reshape(n, 1, d)
        z = z.repeat(1, k, 1)

        mu = centroids.reshape(1, k, d)
        mu = mu.repeat(n, 1, 1)

        dist = (z - mu).norm(2, dim=2).reshape((n, k))
        loss = (dist.min(dim=1)[0]**2).mean()  # note: your version was missing the square that the equation has

        return loss

def sq_loss_clusters(encode_output, centroids):
    assert encode_output.size(1) == centroids.size(1), "Dimension mismatch"
    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); sum squares over d, keep the nearest centroid
    return ((encode_output[:, None] - centroids[None])**2).sum(2).min(1)[0].mean()

loss_fn = KMeansClusteringLoss()
pts = torch.randn(2000, 10)
means = torch.randn(10, 10)
l1 = loss_fn(pts, means)
l2 = sq_loss_clusters(pts, means)
# sqrt-then-square vs. a direct sum of squares can differ in the last bits,
# so compare within floating-point tolerance instead of with ==
assert torch.allclose(l1, l2)
# IPython/Jupyter magics
%timeit loss_fn(pts, means)
%timeit sq_loss_clusters(pts, means)
```

This gives 3.07 ms for your version and 684 µs for mine, so mine is more than 4× faster.
I also think it is clearer what’s going on when it is written more concisely. You could add a comment linking to the equation and noting that s_ik == 1 exactly when assigning point i to cluster k is optimal given the fixed centers. Or you could spread the expression over several lines so it is easier to follow. Tastes vary, and that is fine.
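Here is a sketch of the "spread it over several lines" option. It builds the one-hot s_ik explicitly with `argmin` and `torch.nn.functional.one_hot`, then applies the sum over k from the equation. The function name `kmeans_loss_explicit` is hypothetical. Because s_ik picks out exactly one squared distance per point, the result should match the min-based one-liner. As with `min`, gradients flow only through the selected distances, not through the assignment.

```python
import torch
import torch.nn.functional as F

def kmeans_loss_explicit(encode_output, centroids):
    """k-means loss with an explicit one-hot assignment s_ik (hypothetical helper)."""
    assert encode_output.size(1) == centroids.size(1), "Dimension mismatch"
    # squared distances ||z_i - mu_k||^2, shape (n, k)
    sq_dist = ((encode_output[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=2)
    # s_ik = 1 for the nearest centroid, i.e. the optimal assignment for fixed centers
    assign = sq_dist.argmin(dim=1)
    s = F.one_hot(assign, num_classes=centroids.size(0)).to(sq_dist.dtype)
    # sum_k s_ik * ||z_i - mu_k||^2 keeps only the nearest distance; average over points
    return (s * sq_dist).sum(dim=1).mean()

pts = torch.randn(2000, 10)
means = torch.randn(10, 10)
loss = kmeans_loss_explicit(pts, means)
ref = ((pts[:, None] - means[None]) ** 2).sum(2).min(1)[0].mean()
print(torch.allclose(loss, ref))  # True
```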

Note that the equation you posted had the squared distance, so I used that.

Best regards

Thomas


@tom Thank you so much. I will try this out.

Could you also explain how to use L2 regularization with the mean squared error loss?

Well, either use the optimizer’s weight decay (which adds a multiple of the derivative of |w|² to the gradient) or sum |w|² yourself (something like reg_loss = sum((p**2).sum() for p in m.parameters()) should work).
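A sketch comparing the two options for plain SGD, on made-up data with a hypothetical λ (`lam`). Option 1 adds λ‖w‖² to the `nn.MSELoss` output. Option 2 uses SGD’s `weight_decay`, which adds `weight_decay * w` to each gradient. The gradient of λ‖w‖² is 2λw, so `weight_decay = 2 * lam` should give the same update. Both variants here also penalize the bias, because they iterate over all of `model.parameters()`. The exact equivalence is only claimed for plain SGD.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 5)
y = torch.randn(32, 1)
lam = 1e-2  # hypothetical regularization strength lambda
lr = 0.1
mse = nn.MSELoss()

def l2_penalty(model):
    # ||w||_2^2 summed over every parameter tensor (weights and biases)
    return sum((p ** 2).sum() for p in model.parameters())

model_a = nn.Linear(5, 1)
model_b = copy.deepcopy(model_a)  # identical starting weights

# Option 1: explicit penalty added to the MSE loss
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr)
opt_a.zero_grad()
loss_a = mse(model_a(x), y) + lam * l2_penalty(model_a)
loss_a.backward()
opt_a.step()

# Option 2: plain MSE; SGD adds weight_decay * w to each gradient,
# and d/dw (lam * ||w||^2) = 2 * lam * w, so weight_decay = 2 * lam
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr, weight_decay=2 * lam)
opt_b.zero_grad()
loss_b = mse(model_b(x), y)
loss_b.backward()
opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```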

Best regards

Thomas

@tom
What is the recommended way of doing this, and why?

I’m not an expert on this; my rule of thumb would be to use the built-in weight decay for efficiency if it works for you.
But really, I don’t have any profound insights to offer here.

Best regards

Thomas