How to compute the centroid of each class from prototypes?

I have prototypes and the class label of each one.
I want to get the centroid of each class. How do I compute the centroids?
Here is the shape of the prototypes:

print(prototypes.shape)  # [64, 31]; batch_size: 64, number of classes: 30

Please describe your problem in more detail, ideally with an example. I can interpret your question in several different ways.

I guess that you need: torch.Tensor (PyTorch 1.7.1 documentation)
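A minimal sketch of the plain Tensor-method route for a single batch, assuming `prototypes` has shape [batch_size, feature_dim] and `labels` is a 1-D integer tensor holding one class index per row. The random data below is a stand-in using the shapes from the question: mask the rows of each class and average them with `.mean(dim=0)`.

```python
import torch

torch.manual_seed(0)

# Stand-in data (assumption): 64 samples, 31-dim prototypes, labels in [0, 30)
prototypes = torch.randn(64, 31)
labels = torch.randint(0, 30, (64,))

# One centroid per class that actually occurs in this batch
centroids = {}
for c in labels.unique():
    mask = labels == c                                  # rows belonging to class c
    centroids[int(c)] = prototypes[mask].mean(dim=0)    # mean over those rows

print(len(centroids), next(iter(centroids.values())).shape)
```

Classes that do not appear in the batch simply get no entry, which sidesteps dividing by a zero count.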

Actually, that's all there is to it.
I already have the prototypes, so I just need to compute the centroid of the prototypes for each label.

for inputs, labels in train_loader:
    prototypes = network(inputs)  # network: ResNet + FC layer
    # I want to compute the centroid of each class.
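A batch of 64 may not contain every class, so a centroid over the whole training set needs per-class running sums and counts that are divided only at the end. A minimal sketch, assuming `network` returns [batch_size, feature_dim] and labels are integer class indices in [0, num_classes). The tiny `nn.Linear` network and the list used as `train_loader` are hypothetical stand-ins for the real ResNet and DataLoader.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, feature_dim = 30, 31  # shapes taken from the question

# Hypothetical stand-ins for the real ResNet + FC network and DataLoader
network = nn.Linear(8, feature_dim)
train_loader = [(torch.randn(64, 8), torch.randint(0, num_classes, (64,)))
                for _ in range(5)]

sums = torch.zeros(num_classes, feature_dim)
counts = torch.zeros(num_classes)

network.eval()
with torch.no_grad():  # centroids are statistics, no gradients needed
    for inputs, labels in train_loader:
        prototypes = network(inputs)               # [batch_size, feature_dim]
        sums.index_add_(0, labels, prototypes)     # add each row to its class sum
        counts += torch.bincount(labels, minlength=num_classes).float()

# Per-class mean; clamp avoids 0/0 for classes that were never seen
centroids = sums / counts.clamp(min=1).unsqueeze(1)
print(centroids.shape)  # torch.Size([30, 31])
```

Classes that never occur keep an all-zero centroid; check `counts` if that matters for your use case.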

I think computing the centroids looks something like this:

for prototype, label in zip(all_prototypes, all_labels):
    # add each sample's share of its class mean
    centroid[label] += prototype / number_of_class[label]

But I’m not sure about this.
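Adding each prototype divided by the size of its class does give the class mean, since the class sum divided by the count equals the sum of each member divided by the count. A small worked check with made-up 2-D prototypes (the values, `num_classes`, and `number_of_class` are illustrative assumptions):

```python
import torch

# Made-up example: three 2-D prototypes, classes 0, 1, 0
prototypes = torch.tensor([[1., 2.], [10., 20.], [3., 4.]])
labels = torch.tensor([0, 1, 0])
num_classes = 2

number_of_class = torch.bincount(labels, minlength=num_classes)  # counts per class: [2, 1]
centroid = torch.zeros(num_classes, prototypes.shape[1])

for prototype, label in zip(prototypes, labels):
    centroid[label] += prototype / number_of_class[label]

# centroid[0] is [2., 3.] (mean of rows 0 and 2), centroid[1] is [10., 20.]
print(centroid)
```

The counts must be known before the loop; otherwise accumulate plain sums first and divide once at the end.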

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 3)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.conv3 = nn.Conv2d(20, 30, 3)
        self.maxpool = nn.MaxPool2d((2,2))
        self.linear = nn.Linear(120, 10)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.maxpool(x)
        
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        
        x = F.relu(self.conv3(x))
        x = self.maxpool(x)
        
        x = x.view(x.shape[0], -1)
        x = self.linear(x)
        return x


network = Model() #example

res = []
for i in range(8): #should be: for inputs, labels in train_loader:
    inputs = torch.rand(32, 3, 32, 32) #example
    prototypes = network(inputs) # stand-in for the ResNet + FC network
    prototypes = prototypes.permute(1, 0)
    res.append(prototypes)
    
res = torch.cat(res, dim=1)
print(res.shape)
print(res.mean(dim=1).shape)

Here the mean is calculated over all of the network's outputs, not only the highest-scoring one. In the ideal situation, where there is no bias and the dataset is well balanced, all of these values should be similar.
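The code above averages over every sample regardless of its label. To get one centroid per class instead, the same collect-then-concatenate pattern can keep the labels alongside the outputs and average with a one-hot matrix product. A sketch, assuming integer class labels in [0, num_classes); the random outputs and labels below are stand-ins for `network(inputs)` and the loader's labels.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, out_dim = 10, 10  # matches nn.Linear(120, 10) above

outputs, targets = [], []
for _ in range(8):  # should be: for inputs, labels in train_loader:
    outputs.append(torch.randn(32, out_dim))              # stand-in for network(inputs)
    targets.append(torch.randint(0, num_classes, (32,)))  # stand-in for labels

outputs = torch.cat(outputs, dim=0)  # [256, out_dim]
targets = torch.cat(targets, dim=0)  # [256]

one_hot = F.one_hot(targets, num_classes).float()  # [256, num_classes]
sums = one_hot.t() @ outputs                       # per-class sums, [num_classes, out_dim]
counts = one_hot.sum(dim=0)                        # samples per class, [num_classes]
centroids = sums / counts.clamp(min=1).unsqueeze(1)
print(centroids.shape)  # torch.Size([10, 10])
```

The matrix product keeps everything vectorized; for very many classes, `index_add_` avoids materializing the one-hot matrix.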