K-means plotting torch tensor


This is a home-made implementation of a K-means algorithm for PyTorch.
I have a tensor of dimensions [80, 1000] that holds the cluster centroids, which keep changing until they converge to fixed values.

The labels of the features that are considered the “centers” are stored in the variable “indices_”.

I am having some issues when I try to plot the tensor. I am only considering 2 main classes.

I have also tried plt.scatter, but I don't know much about how to make a proper plot.

Here is the code below:

import torch

def k_means_torch(dictionary, model, max_iter=100):
    # dictionary: [N, 1000] feature tensor; `model` is not used in this function.
    # One random starting centroid per row, created on the same device as the data.
    centroids = torch.randn(len(dictionary), dictionary.shape[1], device=dictionary.device)
    dist_centroids = torch.cdist(dictionary, centroids, p=2.0)
    # For every feature row, the index of its nearest centroid.
    values, indices = torch.min(dist_centroids, dim=1)
    centroids_new = dictionary[indices]
    for _ in range(max_iter):  # guard against a loop that never converges
        dist_centroids_loop = torch.cdist(dictionary, centroids_new, p=2.0)
        values_, indices_ = torch.min(dist_centroids_loop, dim=1)
        new_centers = dictionary[indices_]
        # Stop once no centroid coordinate moves by more than 1e-5.
        if torch.all(torch.abs(centroids_new - new_centers) < 1e-5):
            break
        # Move each centroid halfway towards its new center.
        centroids_new = (new_centers + centroids_new) / 2
        # plt.scatter(centroids_new[0].cpu().numpy(), dictionary[0].cpu().numpy(), alpha=0.5)

    return centroids_new
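Note that the loop above is not standard k-means: it starts from as many centroids as there are data rows and re-selects data points as centers on each pass, instead of averaging the members of each cluster. For comparison, here is a minimal sketch of standard Lloyd's k-means in PyTorch. The name `kmeans_lloyd` and its parameters (`k`, `max_iter`, `tol`, `seed`) are hypothetical, and it assumes `data` is a float tensor of shape [N, D]. It returns per-point labels alongside the centroids, which is what you need to color a scatter plot by cluster.

```python
import torch


def kmeans_lloyd(data: torch.Tensor, k: int, max_iter: int = 100,
                 tol: float = 1e-5, seed: int = 0):
    """Hypothetical Lloyd's k-means sketch. data: [N, D] -> (centroids [k, D], labels [N])."""
    g = torch.Generator().manual_seed(seed)
    # Initialise the centroids with k distinct data points chosen at random.
    init_idx = torch.randperm(data.shape[0], generator=g)[:k].to(data.device)
    centroids = data[init_idx].clone()
    for _ in range(max_iter):
        # Assignment step: index of the nearest centroid for every point, shape [N].
        labels = torch.cdist(data, centroids).argmin(dim=1)
        # Update step: each centroid becomes the mean of the points assigned to it.
        new_centroids = centroids.clone()
        for j in range(k):
            members = data[labels == j]
            if members.shape[0] > 0:  # keep the old centroid if its cluster is empty
                new_centroids[j] = members.mean(dim=0)
        shift = (new_centroids - centroids).abs().max()
        centroids = new_centroids
        if shift < tol:  # converged: no centroid coordinate moved more than tol
            break
    # Final labels consistent with the returned centroids.
    labels = torch.cdist(data, centroids).argmin(dim=1)
    return centroids, labels
```

With the two main classes mentioned above, a call would look like `centroids, labels = kmeans_lloyd(dictionary, k=2)`. The Python loop over clusters is fine for small `k`; for many clusters it could be vectorized with `index_add_`.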

Any ideas on how I could plot this? Should I maybe use another matplotlib function?

It seems you are working with a 1000-dimensional feature space, which would be hard to plot.
You could start by plotting only two of the dimensions with plt.scatter and check the results first.
To reduce the dimensionality of your feature space you could use e.g. PCA, but I'm not sure how this would affect the k-means centroids and data. It might still be worth a try.
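A minimal sketch of the PCA idea, assuming you already have `data` [N, D], `centroids` [k, D] and integer cluster `labels` [N]. The helper names `pca_2d` and `plot_clusters` are hypothetical. It projects both data and centroids onto the first two principal components of the data (computed with `torch.linalg.svd`) and colors the points by label. Keep in mind that this is only a visual check: distances in the 2-D projection are not the distances k-means used.

```python
import torch


def pca_2d(data: torch.Tensor, centroids: torch.Tensor):
    """Project data [N, D] and centroids [k, D] onto the data's first two principal components."""
    data = data.detach().float()
    centroids = centroids.detach().float()
    mean = data.mean(dim=0, keepdim=True)
    centered = data - mean
    # Rows of vh are the principal directions, sorted by decreasing singular value.
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    components = vh[:2].T  # [D, 2]
    # Centroids are centered with the *data* mean so both share the same 2-D frame.
    return centered @ components, (centroids - mean) @ components


def plot_clusters(data_2d, centroids_2d, labels):
    """Scatter 2-D points colored by cluster label and mark the centroids with crosses."""
    import matplotlib.pyplot as plt  # imported lazily so pca_2d works without matplotlib

    def to_np(t):
        return t.detach().cpu().numpy()

    fig, ax = plt.subplots()
    ax.scatter(to_np(data_2d[:, 0]), to_np(data_2d[:, 1]),
               c=to_np(labels), cmap="tab10", s=10, alpha=0.5)
    ax.scatter(to_np(centroids_2d[:, 0]), to_np(centroids_2d[:, 1]),
               c="black", marker="x", s=100, label="centroids")
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.legend()
    return fig
```

Usage: `d2, c2 = pca_2d(data, centroids)`, then `plot_clusters(d2, c2, labels)` followed by `plt.show()`. To plot two raw feature dimensions instead of principal components, pass e.g. `data[:, [0, 1]]` and `centroids[:, [0, 1]]` directly to `plot_clusters`.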

Yeah… unfortunately, what I wanted to implement is not possible. I had a look at PCA, but it's not exactly what I was looking for. Anyway, thanks for the idea :wink: