Cannot use k_means.fit_predict(x) on the output of a pre-trained encoder

I have the MNIST test set. I want to pass the images through a pre-trained encoder and then cluster the embeddings with k-means, but I get an error when calling fit_predict().

This is the code:

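# Assumed imports (not shown in the original post):
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score as sil

# root, k and model (the pre-trained encoder) are defined elsewhere.
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])

test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=10000,
    shuffle=False)

km = KMeans(k, n_init=20, n_jobs=4)
for data in test_loader:
    x, _ = data
    x = model(x.cuda())
    x = x.detach().cpu().numpy()  # embeddings as a NumPy array for scikit-learn
    #x = x.astype(int)
    y_pred = km.fit_predict(x) # seems we can only get a centre from batch
    sil_score = sil(x, y_pred)
    print('sil score', sil_score)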

And this is the error I get:

_RemoteTraceback                          Traceback (most recent call last)
TypeError: 'float' object cannot be interpreted as an integer

I tried adding x = x.astype(int):

km = KMeans(k, n_init=20, n_jobs=4)
for data in test_loader_0:
    x, _ = data
    x = model(x.cuda())
    x = x.detach().cpu().numpy()  # embeddings as a NumPy array
    x = x.astype(int)
    y_pred = km.fit_predict(x) # seems we can only get a centre from batch
    sil_score = sil(x, y_pred)
    print('sil score', sil_score)

but got the same error. I find this error very strange, because I have already used the same network (model = encoder) on the training set of the same dataset to compute k-means labels. I don't think KMeans.fit_predict only accepts integer values.
Has anyone got a clue about this, or encountered this issue? I'd appreciate a hint.

It's not a problem with X: you can fit float data, not just ints, and the sample code below works. I suspect the k you are passing is not an int; can you check? The number of clusters has to be an int.

from sklearn.cluster import KMeans
import numpy as np
X = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]], dtype=float)
kmeans = KMeans(n_clusters=2, random_state=0).fit_predict(X)

out: array([1, 1, 1, 0, 0, 0], dtype=int32)
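To confirm that the cluster count, not X, triggers the error, here is a minimal sketch reusing the same toy data: an integer `n_clusters` works, while the float `2.0` is rejected. The exact exception depends on the scikit-learn version. Older releases fail deep inside with `TypeError: 'float' object cannot be interpreted as an integer`, as in the traceback above, while newer ones validate parameters before fitting. So the sketch catches both `TypeError` and `ValueError`. `try_fit` is a hypothetical helper written only for this illustration.

```python
import numpy as np
from sklearn.cluster import KMeans

# Same toy data as above: float features are fine.
X = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]], dtype=float)


def try_fit(n_clusters):
    """Hypothetical helper: return labels, or the exception name if KMeans rejects n_clusters."""
    try:
        return KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(X)
    except (TypeError, ValueError) as exc:  # exact type varies across versions
        return type(exc).__name__


labels_int = try_fit(2)      # integer cluster count: works
result_float = try_fit(2.0)  # float cluster count: rejected
print(labels_int)
print(result_float)
```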

Thank you for the help. k was the issue: it was a float, and after converting it to an int my code works :)
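The thread doesn't say how `k` became a float. One common way this happens in Python 3 is true division: `/` always returns a float, so even `20 / 2` gives `10.0`. A small guard like the hypothetical `as_cluster_count` helper below (not part of scikit-learn) converts whole-number floats to `int` and refuses fractional or non-positive values, so a bad `k` fails early with a clear message. Random data stands in for the encoder embeddings here, which is an assumption of the sketch.

```python
import math
import numbers

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def as_cluster_count(k):
    """Hypothetical helper: return k as a positive int, rejecting fractional values."""
    if isinstance(k, bool) or not isinstance(k, numbers.Real):
        raise TypeError(f"number of clusters must be a number, got {type(k).__name__}")
    if not math.isfinite(k) or k < 1 or k != int(k):
        raise ValueError(f"number of clusters must be a positive whole number, got {k!r}")
    return int(k)


k = 20 / 2                # true division: k == 10.0, a float
k = as_cluster_count(k)   # now the int 10

# Random float data standing in for the encoder embeddings (assumption).
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 16))

km = KMeans(n_clusters=k, n_init=20, random_state=0)
y_pred = km.fit_predict(x)
print("sil score", silhouette_score(x, y_pred))
```

`n_jobs` is left out because recent scikit-learn releases no longer accept it in `KMeans`.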