I have prototypes and the class label of each one.
I want to get the centroid of each class. How do I compute the centroids?
Here is the shape of prototypes:
print(prototypes.shape): [64, 31]  # batch_size: 64, number of classes: 30
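Put formally, the centroid of class k is the mean of all prototypes whose label is k. The notation here is introduced for illustration: p_i is the prototype (one row of the prototypes tensor) of sample i, y_i is its label, and S_k is the set of samples labelled k:

```latex
c_k = \frac{1}{|S_k|} \sum_{i \in S_k} p_i, \qquad S_k = \{\, i : y_i = k \,\}
```

So each centroid needs two things per class: the sum of its prototypes and the number of samples in it.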
Could you describe your problem in more detail, perhaps with some examples? I can interpret it in several different ways.
I guess that you need: torch.Tensor — PyTorch 1.7.1 documentation
Actually, that's all there is to it.
I already have the prototypes, so I just need to compute the centroid of the prototypes for each label.
for inputs, labels in train_loader:
    prototypes = network(inputs)  # network: Resnet + FC layer.
    # I want to compute the centroid of each class.
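Because this loop sees the data one batch at a time, one way to get centroids over the whole training set is to keep a running sum and a running count per class and divide once at the end. This is a minimal sketch assuming integer labels in [0, num_classes) and prototypes of shape [batch_size, dim]. The list `fake_loader` is a hypothetical stand-in that yields (prototypes, labels) directly, in place of `train_loader` plus `network`.

```python
import torch

torch.manual_seed(0)

num_classes, dim = 30, 31  # hypothetical sizes, based on the shape in the question
# Hypothetical stand-in for train_loader + network: yields (prototypes, labels)
fake_loader = [
    (torch.randn(64, dim), torch.randint(0, num_classes, (64,)))
    for _ in range(5)
]

sums = torch.zeros(num_classes, dim)   # running sum of prototypes per class
counts = torch.zeros(num_classes)      # running number of samples per class

with torch.no_grad():  # centroids are plain statistics, no gradients needed
    for prototypes, labels in fake_loader:
        # Add each prototype row into the row of its class
        sums.index_add_(0, labels, prototypes)
        # Count how many samples of each class appeared in this batch
        counts += torch.bincount(labels, minlength=num_classes).float()

# Divide per class; clamp avoids 0/0 for classes that never appeared
centroids = sums / counts.clamp(min=1).unsqueeze(1)
print(centroids.shape)  # torch.Size([30, 31])
```

Accumulating sums and counts, rather than averaging per-batch means, gives the exact mean even when classes appear a different number of times in different batches.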
I think computing the centroid is something like this:
centroid = torch.zeros(num_classes, prototypes.shape[1])
for i in range(num_classes):
    mask = labels == i  # samples belonging to class i
    centroid[i] = prototypes[mask].sum(dim=0) / mask.sum()
But I’m not sure about this.
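Computing a centroid amounts to summing the prototypes of each class and dividing by the number of samples in that class. A loop-free sketch for a single batch, assuming `prototypes` of shape [batch_size, dim] and integer `labels` of shape [batch_size]: it builds a one-hot label matrix with `torch.nn.functional.one_hot`, so a single matrix product gives all per-class sums. The helper name `class_centroids` and the random inputs are hypothetical.

```python
import torch
import torch.nn.functional as F


def class_centroids(prototypes: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return a [num_classes, dim] tensor whose row k is the mean prototype of class k.

    Classes with no samples in the batch get a row of zeros.
    """
    # one_hot[i, k] == 1 iff sample i has label k; shape [batch_size, num_classes]
    one_hot = F.one_hot(labels, num_classes).to(prototypes.dtype)
    # Row k of one_hot.T @ prototypes is the sum of all prototypes with label k
    sums = one_hot.T @ prototypes
    # Column sums of one_hot are the per-class sample counts
    counts = one_hot.sum(dim=0)
    return sums / counts.clamp(min=1).unsqueeze(1)


torch.manual_seed(0)
prototypes = torch.randn(64, 31)       # hypothetical batch: 64 samples, 31 features
labels = torch.randint(0, 30, (64,))   # hypothetical labels for 30 classes
centroids = class_centroids(prototypes, labels, num_classes=30)
print(centroids.shape)  # torch.Size([30, 31])
```

The same per-class sums can also be built with `Tensor.index_add_`, which avoids materializing the one-hot matrix; with only 30 classes either approach is cheap.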
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 3)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.conv3 = nn.Conv2d(20, 30, 3)
        self.maxpool = nn.MaxPool2d((2, 2))
        # 32x32 input -> 2x2 after three conv+pool stages, 30 channels * 2 * 2 = 120
        self.linear = nn.Linear(120, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.maxpool(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv3(x))
        x = self.maxpool(x)
        x = x.view(x.shape[0], -1)  # flatten to [batch_size, 120]
        x = self.linear(x)
        return x

network = Model()  # example
res = []
for i in range(8):  # should be: for inputs, labels in train_loader:
    inputs = torch.rand(32, 3, 32, 32)  # example
    prototypes = network(inputs)  # here: the example Model above, output [32, 10]
    prototypes = prototypes.permute(1, 0)  # [10, 32]: one row per output
    res.append(prototypes)
res = torch.cat(res, dim=1)  # [10, 256]
print(res.shape)  # torch.Size([10, 256])
print(res.mean(dim=1).shape)  # torch.Size([10])
Here the mean is calculated over all outputs of the network, not just the highest one. Ideally (with no bias and a well-balanced dataset), all values should be similar.
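Averaging every output over all samples gives one overall mean vector rather than one centroid per class. To get per-class centroids from the same kind of loop, the labels have to be collected along with the outputs. A minimal sketch: a hypothetical `nn.Linear` stands in for `network`, random tensors stand in for `train_loader`, and `num_classes = 10` is an assumption. It keeps the outputs as [N, dim] (no permute) and averages the rows that share a label.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
network = nn.Linear(16, 10)  # hypothetical stand-in for the real network
num_classes = 10             # hypothetical number of classes

outputs, targets = [], []
with torch.no_grad():  # no gradients needed just to compute statistics
    for _ in range(8):  # should be: for inputs, labels in train_loader:
        inputs = torch.rand(32, 16)                    # example inputs
        labels = torch.randint(0, num_classes, (32,))  # example labels
        outputs.append(network(inputs))  # shape [32, 10]
        targets.append(labels)

outputs = torch.cat(outputs, dim=0)  # [256, 10]: one row per sample
targets = torch.cat(targets, dim=0)  # [256]

overall_mean = outputs.mean(dim=0)  # [10]: mean over all samples, ignores labels
# One centroid per class present in the data: average the rows with that label
centroids = {
    int(k): outputs[targets == k].mean(dim=0) for k in torch.unique(targets)
}
print(overall_mean.shape, len(centroids))
```

The dictionary only contains classes that actually occur, which sidesteps dividing by zero for missing classes.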