Translating an argmax index to its category


(Romano) #1

I made a custom dataset with a DataLoader, containing images from 3 different categories. I used VGG16 to predict which category each image belongs to.

However, when I predict a single image, I get back something like this:

```
tensor([[-0.0559, -1.6212, -0.3467]], grad_fn=)
```

How do I know which category corresponds to index 0, index 1 or index 2?
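The printed tensor holds one row per image and one raw score (logit) per class, so `argmax` only tells you which column scored highest; it says nothing about which category that column stands for. A minimal sketch using the values above:

```python
import torch

# Logits as printed above: one row (a batch of 1 image), one column per class
logits = torch.tensor([[-0.0559, -1.6212, -0.3467]])

# argmax over the class dimension gives the winning column index per image
pred_idx = torch.argmax(logits, dim=1)
print(pred_idx.item())  # 0

# softmax rescales the scores into probabilities that sum to 1 per row
probs = torch.softmax(logits, dim=1)
print(probs)
```

Applying softmax (or sigmoid) first does not change the argmax, since both preserve the order of the scores; it only makes them easier to read as confidences.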

I've seen other examples that do something like this:

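```python
import torch

# img is assumed to be a preprocessed batch of images, shape (N, C, H, W)
net = torch.load('pytorch_Network2.h5')
net.eval()  # disable dropout for inference
idx_to_class = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
}

with torch.no_grad():
    output = net(img)

pred = torch.argmax(output, 1)

for p in pred:
    cls = idx_to_class[p.item()]
    print(cls)
```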
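In that snippet, index 0 is 'airplane' presumably because the model was trained on CIFAR-10, whose first three classes are, in order, airplane, automobile and bird; a hard-coded dict is only correct if it matches the label order used during training. When a dataset object exposes a name-to-index dict, such as the `class_to_idx` attribute of torchvision's `ImageFolder`, the reverse lookup can be derived rather than typed by hand. A sketch with a made-up dict and made-up predictions:

```python
# Hypothetical name -> index mapping, standing in for e.g. ImageFolder.class_to_idx
class_to_idx = {"airplane": 0, "automobile": 1, "bird": 2}

# Invert it so a predicted index can be looked up directly
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

pred = [2, 0]  # hypothetical predicted indices
print([idx_to_class[p] for p in pred])  # ['bird', 'airplane']
```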

But then, how would you know that index 0 is an airplane? Does that depend on how you build your dataset? If so, this is how I've built it:

```python
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset
from torchvision import transforms

# IMG_SIZE, prepare_dataset() and determine_imagename() are defined elsewhere (not shown)


class SophistyDataSet(Dataset):
    """Dataset wrapping images and target labels for 3 categories 'Modern','Vintage','Classic'

    Arguments:
        A CSV file path
        Path to image folder
        Extension of images
        PIL transforms
    """

    def __init__(self, directory_list, img_path):
        raw_images_list = self.prepare_dataset(directory_list)
        images_df = pd.DataFrame(raw_images_list, columns=['name', 'tag'])
        self.mlb = MultiLabelBinarizer()
        self.transform = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor()])
        self.img_path = img_path
        self.X_train = images_df['name']
        # Column order of y_train is given by self.mlb.classes_
        self.y_train = self.mlb.fit_transform(images_df['tag'].str.split()).astype(np.float32)

    def __getitem__(self, index):
        img = Image.open(
            self.img_path + self.determine_imagename(self.X_train[index]) + '/' + self.X_train[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        label = self.y_train[index]
        return img, label

    def __len__(self):
        return len(self.X_train.index)
```
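In this dataset the order is decided by `MultiLabelBinarizer`: after `fit_transform`, column `i` of `y_train` corresponds to `self.mlb.classes_[i]`, and when no `classes` argument is passed, `classes_` is the sorted set of tags seen during fitting. A minimal sketch, assuming the tags are exactly 'Modern', 'Vintage' and 'Classic' (the sample list below is made up):

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Made-up tags, one per image, split into lists as in SophistyDataSet.__init__
tags = [t.split() for t in ["Modern", "Vintage", "Classic", "Modern"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(tags)

# classes_ is sorted, so the column order is alphabetical, not the docstring order
print(list(mlb.classes_))  # ['Classic', 'Modern', 'Vintage']

# Column i of y (and output i of a network trained on y) is mlb.classes_[i]
idx_to_class = dict(enumerate(mlb.classes_))
print(y[0], idx_to_class[int(y[0].argmax())])  # [0 1 0] Modern
```

Passing `MultiLabelBinarizer(classes=[...])` fixes the order explicitly instead of relying on sorting.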

(Thomas V) #2

You need to keep track of this mapping.
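One way to keep track of it, sketched here as an assumption rather than a standard format: store the class list in the same checkpoint as the weights, so the index order travels with the model. The key names `model_state` and `classes` are arbitrary, and a tiny `nn.Linear` stands in for VGG16:

```python
import io

import torch
import torch.nn as nn

classes = ["Classic", "Modern", "Vintage"]  # e.g. list(dataset.mlb.classes_)

model = nn.Linear(4, len(classes))  # tiny stand-in for the real VGG16

# Save the class order next to the weights so it cannot get lost
buffer = io.BytesIO()  # a file path such as "checkpoint.pt" works the same way
torch.save({"model_state": model.state_dict(), "classes": classes}, buffer)

# Later: rebuild the model, then restore the weights and the mapping together
buffer.seek(0)
checkpoint = torch.load(buffer)
restored = nn.Linear(4, len(checkpoint["classes"]))
restored.load_state_dict(checkpoint["model_state"])
restored.eval()

with torch.no_grad():
    logits = restored(torch.randn(1, 4))
pred = logits.argmax(dim=1).item()
print(checkpoint["classes"][pred])
```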

Best regards

Thomas