Replacing a for loop with indexing: k-means

I’m trying to replace the "for" loop with indexing.

This is the code:
# Cluster each of the 200 dataset entries independently, 4 centers each.
# Only the result of the last iteration is kept in `centers` / `codes`.
for i in range(200):
    centers, codes = cluster(dataset[i], 4)

Could you post the definition of cluster?
If you want to replace the loop, cluster would have to process batched data.

Hi ptrblck,
How can I make the function cluster process batched data? The code of cluster is as follows:
import torch
import numpy as np
import random
import sys
import timeit
# Device handle for the CUDA GPU (plain ASCII quotes are required here).
device_gpu = torch.device('cuda')

# Choosing num_centers random data points as the initial centers

def random_init(dataset, num_centers):
    """Pick ``num_centers`` distinct rows of ``dataset`` as initial centers.

    Args:
        dataset: 2-D tensor of shape ``(num_points, dimension)``.
        num_centers: how many distinct data points to draw.

    Returns:
        A new tensor of shape ``(num_centers, dimension)`` holding the
        selected rows, on the same device as ``dataset``.

    Raises:
        ValueError: if ``num_centers`` is negative or larger than the
            number of data points (distinct rows cannot be drawn).
    """
    num_points = dataset.size(0)
    # random.sample draws without replacement, so no row is chosen twice
    # and the call fails fast instead of spinning when too many centers
    # are requested.
    picked = random.sample(range(num_points), num_centers)
    # Build the index tensor on the dataset's own device so the lookup
    # works for both CPU and GPU inputs.
    indices = torch.tensor(picked, dtype=torch.long, device=dataset.device)
    return dataset[indices]

# Compute for each data point the closest center

def compute_codes(dataset, centers):
    """Return, for every data point, the index of its closest center.

    Closeness is measured by squared Euclidean distance, evaluated in
    chunks of rows so the distance matrix never has to be held in full.

    Args:
        dataset: 2-D tensor of shape ``(num_points, dimension)``.
        centers: 2-D tensor of shape ``(num_centers, dimension)`` with at
            least one row, on the same device as ``dataset``.

    Returns:
        A ``torch.long`` tensor of shape ``(num_points,)`` on the same
        device as ``dataset``; entry ``k`` is the row index in ``centers``
        closest to ``dataset[k]``.
    """
    num_points = dataset.size(0)
    num_centers = centers.size(0)
    # 5e8 should vary depending on the free memory on the device.
    # Clamp to 1 so a huge number of centers cannot yield a zero step.
    chunk_size = max(1, int(5e8 / num_centers))
    codes = torch.zeros(num_points, dtype=torch.long, device=dataset.device)
    centers_t = torch.transpose(centers, 0, 1)
    centers_norms = torch.sum(centers ** 2, dim=1).view(1, -1)
    for begin in range(0, num_points, chunk_size):
        end = min(begin + chunk_size, num_points)
        dataset_piece = dataset[begin:end, :]
        dataset_norms = torch.sum(dataset_piece ** 2, dim=1).view(-1, 1)
        # ||x - c||^2 = ||x||^2 - 2 * <x, c> + ||c||^2, built in place to
        # avoid allocating extra (chunk, num_centers) temporaries.
        distances = torch.mm(dataset_piece, centers_t)
        distances *= -2.0
        distances += dataset_norms
        distances += centers_norms
        _, min_ind = torch.min(distances, dim=1)
        codes[begin:end] = min_ind
    return codes

# Compute new centers as means of the data points forming the clusters

def update_centers(dataset, codes, num_centers):
    """Recompute each center as the mean of the points assigned to it.

    Args:
        dataset: 2-D float tensor of shape ``(num_points, dimension)``.
        codes: ``torch.long`` tensor of shape ``(num_points,)`` giving the
            cluster index of every data point, on the same device as
            ``dataset``.
        num_centers: total number of clusters.

    Returns:
        A ``torch.float`` tensor of shape ``(num_centers, dimension)`` on
        the same device as ``dataset``. A cluster that received no points
        is returned as the zero vector.
    """
    num_points = dataset.size(0)
    dimension = dataset.size(1)
    device = dataset.device
    centers = torch.zeros(num_centers, dimension, dtype=torch.float, device=device)
    cnt = torch.zeros(num_centers, dtype=torch.float, device=device)
    # Sum the points of each cluster row-wise, and count cluster sizes.
    centers.scatter_add_(0, codes.view(-1, 1).expand(-1, dimension), dataset)
    cnt.scatter_add_(0, codes, torch.ones(num_points, dtype=torch.float, device=device))
    # Avoiding division by zero: empty clusters are divided by 1 instead,
    # which leaves their (all-zero) sum untouched.
    cnt = torch.where(cnt > 0.5, cnt, torch.ones(num_centers, dtype=torch.float, device=device))
    centers /= cnt.view(-1, 1)
    return centers
def cluster(dataset, num_centers, max_iterations=None):
    """Run k-means on ``dataset`` and return the centers and assignments.

    Starts from randomly chosen data points and alternates between
    recomputing the centers and reassigning the points.

    Args:
        dataset: 2-D tensor of shape ``(num_points, dimension)``.
        num_centers: number of clusters.
        max_iterations: optional upper bound on the number of update
            rounds. ``None`` (the default) iterates until the assignment
            stops changing.

    Returns:
        A tuple ``(centers, codes)``: the final centers and, for every
        data point, the index of its closest center.
    """
    centers = random_init(dataset, num_centers)
    codes = compute_codes(dataset, centers)
    num_iterations = 0
    while True:
        num_iterations += 1
        centers = update_centers(dataset, codes, num_centers)
        new_codes = compute_codes(dataset, centers)
        # Waiting until the clustering stops updating altogether.
        # This is too strict in practice; max_iterations offers a cap.
        if torch.equal(codes, new_codes):
            break
        codes = new_codes
        if max_iterations is not None and num_iterations >= max_iterations:
            break
    return centers, codes

if __name__ == '__main__':
    # NOTE(review): `dataset` is not defined anywhere in this snippet; it
    # presumably holds 200 separate 2-D point sets -- confirm and load it
    # before this loop.
    print('Starting clustering')
    # Cluster each entry independently; only the last result is kept.
    for i in range(200):
        centers, codes = cluster(dataset[i], 2)