Weight clustering

Is there an equivalent implementation of weight clustering in PyTorch, like the one TensorFlow provides (Weight clustering TensorFlow)?

If there isn't, can someone help me confirm whether what I have done is the right approach:

import torch
from sklearn.cluster import KMeans
# from kmeans_pytorch import kmeans, kmeans_predict

'''Going to go through all the layers --> obtain their weights --> use scikit-learn to do the clustering --> replace these weights by their centroids'''
def weight_clustering(model, n_clusters=3):
  model.to('cpu')
  with torch.no_grad():
    for name, params in model.named_parameters():
      # KMeans expects a 2-D (n_samples, n_features) array, so flatten to (N, 1).
      weights = params.detach().reshape(-1, 1)
      # KMeans raises if there are fewer samples than clusters (e.g. tiny bias vectors).
      if weights.shape[0] < n_clusters:
        continue
      kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(weights.numpy())

      # (k, 1) centroid array -> 1-D tensor of k values.
      cluster_centers = torch.from_numpy(kmeans.cluster_centers_).flatten()
      labels = torch.from_numpy(kmeans.labels_).long()

      # Replace every weight by its cluster's centroid in one vectorized step
      # (same result as looping over kmeans.labels_ with an if/elif chain).
      # copy_ writes into the existing storage and casts to the parameter's dtype.
      params.copy_(cluster_centers[labels].reshape(params.shape))
  return model

This is a fairly crude approach, and I am not sure it is correct. Can someone confirm whether it is, or suggest a better way to do this?
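If you want to drop scikit-learn and the per-element loop (the commented-out `kmeans_pytorch` import hints at that), a minimal sketch of 1-D k-means written directly in PyTorch could look like this. It is plain Lloyd's algorithm with linearly spaced initial centroids and a fixed iteration count; both are assumptions of this sketch, not the behaviour of TensorFlow's clustering API. `kmeans_1d` and `cluster_parameters_` are hypothetical names.

```python
import torch


def kmeans_1d(values: torch.Tensor, n_clusters: int, n_iters: int = 20):
    """Plain Lloyd's k-means on a 1-D tensor; returns (centroids, labels)."""
    # Assumption: linearly spaced initial centroids between the min and max value.
    centroids = torch.linspace(values.min().item(), values.max().item(),
                               n_clusters, dtype=values.dtype)
    for _ in range(n_iters):
        # Nearest centroid for every value via an (N, k) distance matrix.
        labels = (values[:, None] - centroids[None, :]).abs().argmin(dim=1)
        # Move each centroid to the mean of its members; empty clusters stay put.
        for k in range(n_clusters):
            members = values[labels == k]
            if members.numel() > 0:
                centroids[k] = members.mean()
    labels = (values[:, None] - centroids[None, :]).abs().argmin(dim=1)
    return centroids, labels


@torch.no_grad()
def cluster_parameters_(model: torch.nn.Module, n_clusters: int = 3):
    """Overwrite every parameter value with its centroid, in place."""
    for _, param in model.named_parameters():
        centroids, labels = kmeans_1d(param.detach().reshape(-1), n_clusters)
        # Gather the centroid values and copy them into the existing storage.
        param.copy_(centroids[labels].reshape(param.shape))
    return model


# Usage on a small toy model.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
cluster_parameters_(model, n_clusters=3)
for name, p in model.named_parameters():
    print(name, p.unique().numel())  # at most 3 distinct values per parameter
```

Like the scikit-learn version, this only changes the values: each parameter keeps its full dense storage. The (N, k) distance matrix keeps the code short but needs N * k temporaries, so very large layers may need to be processed in chunks.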

The link in my question does not seem to open. This should be the correct one: Weight clustering

I don’t think the code you’ve posted achieves your desired results.
Based on the description of the TF weight clustering approach, it seems as if parameters are reused (i.e. views are actually used) in the new layers, so that the original memory footprint would be reduced.
In your current code snippet you would replace some values of the parameters, but all values would still be stored independently, wouldn’t they?
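A quick way to check this, as a minimal sketch with made-up sizes (a random 256×256 float32 weight and 16 clusters, with quantile-based centroids standing in for k-means): the clustered dense tensor has at most 16 distinct values but still stores all 65,536 elements, and building it with `centers[labels]` allocates fresh memory rather than a view of `centers`. Storing the centroids plus one `uint8` index per weight is what would actually shrink the footprint.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(256, 256)  # float32: 4 bytes per element
n_clusters = 16

# Crude stand-in for k-means: centroids at evenly spaced quantiles,
# each weight assigned to its nearest centroid.
flat = weight.flatten()
centers = torch.quantile(flat, torch.linspace(0, 1, n_clusters))
labels = (flat[:, None] - centers[None, :]).abs().argmin(dim=1)

# Dense clustered tensor: few distinct values, but every element is still stored.
clustered = centers[labels].reshape(weight.shape)
dense_bytes = clustered.numel() * clustered.element_size()  # 65536 * 4 = 262144

# Advanced indexing gathers into new memory; it is not a view of `centers`.
shares_memory = clustered.data_ptr() == centers.data_ptr()

# Compact form: k float32 centroids + one uint8 index per weight (needs k <= 256).
indices = labels.to(torch.uint8).reshape(weight.shape)
compact_bytes = (centers.numel() * centers.element_size()
                 + indices.numel() * indices.element_size())  # 64 + 65536 = 65600

# The dense tensor can be rebuilt exactly from the compact form.
assert torch.equal(centers[indices.long()], clustered)
print(clustered.unique().numel(), dense_bytes, compact_bytes, shares_memory)
```

The byte counts follow directly from the shapes chosen here; the `uint8` index limits this layout to at most 256 clusters per tensor.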

Hello @ptrblck, yes, I want to reduce the memory footprint as the link suggests. When I check the code above, as you correctly pointed out, there is no reduction in the memory footprint: the model stays the same size. I will try to make use of ‘views’ as you suggested.

@ptrblck I tried something like this, although I don’t think it’s entirely correct. I tried to validate it with print statements, and they all return True. However, with this version the model size increases. Could you point me in the right direction?

import torch
from sklearn.cluster import KMeans

def weight_clustering(model, n_clusters=5):
  model.to('cpu')
  with torch.no_grad():
    for count, (name, params) in enumerate(model.backbone.named_parameters()):
      param_shape = params.shape
      weights = params.detach().reshape(-1, 1)
      kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(weights.numpy())
      cluster_centers = torch.from_numpy(kmeans.cluster_centers_)
      print("Processing for layer ", count)

      # Each entry is a 1-element view into cluster_centers, so here no weight data
      # is duplicated (cluster_list[i].data_ptr() == cluster_centers[label].data_ptr()).
      cluster_list = [cluster_centers[int(label)].view(1) for label in kmeans.labels_]

      # torch.cat (like torch.tensor on the list) allocates a new contiguous tensor
      # and copies every value, so the views from above are lost at this point.
      # Cast to the parameter's dtype in case the centroids come back as float64.
      cluster_tensor = torch.cat(cluster_list).reshape(param_shape).to(params.dtype)
      params.data = cluster_tensor.data
      # Always True: params.data was just set to cluster_tensor's storage,
      # so this does not show that memory is shared with cluster_centers.
      print(params.data_ptr() == cluster_tensor.data_ptr())

  return model

model_check = weight_clustering(model)
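One direction that does reduce the stored size, offered as a hedged sketch rather than a PyTorch or TensorFlow API: a strided PyTorch tensor cannot make each element a view of an arbitrary centroid, so instead keep the `k` centroids as a small `nn.Parameter` and the per-weight cluster assignment as a `uint8` buffer, and rebuild the dense weight in `forward`. `ClusteredLinear`, `from_linear` and `state_dict_bytes` are hypothetical names, and the quantile-based centroids are only a placeholder for real k-means.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClusteredLinear(nn.Module):
    """Hypothetical Linear layer storing its weight as centroids + per-weight indices."""

    def __init__(self, centroids: torch.Tensor, indices: torch.Tensor, bias=None):
        super().__init__()
        # k trainable centroid values.
        self.centroids = nn.Parameter(centroids.clone())
        # One uint8 cluster index per weight element (assumes k <= 256); not trained.
        self.register_buffer("indices", indices.to(torch.uint8))
        self.bias = None if bias is None else nn.Parameter(bias.clone())

    @classmethod
    @torch.no_grad()
    def from_linear(cls, linear: nn.Linear, n_clusters: int = 16):
        w = linear.weight.detach().reshape(-1)
        # Crude stand-in for k-means: quantile centroids + nearest-centroid labels.
        # Swap in any clustering you prefer (e.g. sklearn's KMeans from the question).
        q = torch.linspace(0, 1, n_clusters, dtype=w.dtype)
        centroids = torch.quantile(w, q)
        labels = (w[:, None] - centroids[None, :]).abs().argmin(dim=1)
        bias = None if linear.bias is None else linear.bias.detach()
        return cls(centroids, labels.reshape(linear.weight.shape), bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Rebuild the dense (out_features, in_features) weight for this call only.
        weight = self.centroids[self.indices.long()]
        return F.linear(x, weight, self.bias)


def state_dict_bytes(module: nn.Module) -> int:
    return sum(t.numel() * t.element_size() for t in module.state_dict().values())


# Usage: convert one layer and compare what would be saved to disk.
torch.manual_seed(0)
lin = nn.Linear(256, 128)
clin = ClusteredLinear.from_linear(lin, n_clusters=16)
x = torch.randn(4, 256)
y = clin(x)
print(state_dict_bytes(lin), state_dict_bytes(clin))  # 131584 33344
```

The saving applies to the `state_dict` (what `torch.save(model.state_dict(), ...)` writes); the full weight is still materialized temporarily in every forward pass. Because indexing is differentiable, gradients flow back to the shared centroids, so they can be fine-tuned while the cluster assignments stay fixed.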