What is the simplest way to change class_to_idx attribute?

Hello everyone, I'm defining my data and dataloaders in the following way:

	from torch.utils.data import DataLoader
	from torchvision import datasets

	data = {
		'train': datasets.ImageFolder(root=train_path, transform=image_transforms['train']),
		'valid': datasets.ImageFolder(root=valid_path, transform=image_transforms['valid']),
	}
	dataloaders = {
		'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True),
		'val': DataLoader(data['valid'], batch_size=batch_size, shuffle=True),
	}

In my problem I have four classes (A, AB, AC, ABC). When I print data['train'].class_to_idx, I get the following output:

{'A': 0, 'AB': 1, 'ABC': 2, 'AC': 3}

All I want to do is change the index of AC to 2 and the index of ABC to 3, so that I get the following result:

{'A': 0, 'AB': 1, 'ABC': 3, 'AC': 2}
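For context: ImageFolder builds class_to_idx by sorting the class folder names and numbering them in that order. Strings compare character by character, so "ABC" sorts before "AC" ('B' < 'C'), which is where the mapping above comes from. A small pure-Python sketch of that default ordering next to an explicit ordering (desired_order is just a hand-written list, not a torchvision setting):

```python
# Default ImageFolder behaviour: sort the folder names, then number them 0..N-1
folders = ["A", "AB", "AC", "ABC"]
default_class_to_idx = {name: i for i, name in enumerate(sorted(folders))}
print(default_class_to_idx)  # {'A': 0, 'AB': 1, 'ABC': 2, 'AC': 3}

# Explicit ordering: list the classes in the order the indices should follow
desired_order = ["A", "AB", "AC", "ABC"]
custom_class_to_idx = {name: i for i, name in enumerate(desired_order)}
print(custom_class_to_idx)  # {'A': 0, 'AB': 1, 'AC': 2, 'ABC': 3}
```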

What is the simplest way to do that, and what other changes do I have to make in my code?

Thanks :smiley:

I would recommend creating a custom Dataset, possibly reusing some parts of the ImageFolder dataset from here.
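Changing this attribute after the ImageFolder has been created is likely wrong, since class_to_idx is used to create the dataset as seen here: each entry in samples (and targets) already stores its integer class index, so editing class_to_idx afterwards would not relabel any of the samples.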
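A hypothetical pure-Python sketch of why editing the attribute alone is not enough: an ImageFolder keeps each sample as a (path, index) tuple in samples (and the same indices in targets), computed once when the dataset is constructed. Relabelling an existing dataset would therefore mean translating every stored index, which is what the made-up helper remap_samples below does. This only illustrates the bookkeeping involved; it is not the approach recommended above.

```python
from typing import Dict, List, Tuple

Sample = Tuple[str, int]  # (image path, class index), the shape of ImageFolder.samples entries


def remap_samples(
    samples: List[Sample],
    old_class_to_idx: Dict[str, int],
    new_class_to_idx: Dict[str, int],
) -> List[Sample]:
    """Translate every stored index from the old mapping to the new one (hypothetical helper)."""
    # Invert the old mapping: index -> class name
    idx_to_class = {idx: name for name, idx in old_class_to_idx.items()}
    return [(path, new_class_to_idx[idx_to_class[idx]]) for path, idx in samples]


old = {"A": 0, "AB": 1, "ABC": 2, "AC": 3}
new = {"A": 0, "AB": 1, "AC": 2, "ABC": 3}
# Made-up file paths, only to show the shape of the data
samples = [("train/ABC/img0.png", 2), ("train/AC/img1.png", 3), ("train/A/img2.png", 0)]

remapped = remap_samples(samples, old, new)
print(remapped)
# [('train/ABC/img0.png', 3), ('train/AC/img1.png', 2), ('train/A/img2.png', 0)]
```

On a real dataset, targets, imgs (an alias of samples in ImageFolder), classes and class_to_idx would all have to stay in sync, for both the training and the validation dataset, which is why building the dataset with the right mapping from the start is less error-prone.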


Would you like to share your solution?

You can create a CustomImageFolder class that subclasses ImageFolder and overrides its find_classes method (which recent torchvision versions call while constructing the dataset), like this:

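import json
from typing import Dict, List, Tuple

from torchvision.datasets import ImageFolder


def classes_to_idx(path: str = "class_indexes.json") -> Dict[str, int]:
    # Load the class-name -> index mapping from a JSON settings file
    with open(path) as f:
        return json.load(f)


class CustomImageFolder(ImageFolder):
    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Override this method to load from a settings file instead of scanning the directory."""
        # Different local name, so the module-level function is not shadowed
        class_to_idx = classes_to_idx()
        # classes must be ordered by index: classes[i] is the name of class i
        classes = sorted(class_to_idx, key=class_to_idx.get)
        return classes, class_to_idx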

Here is a sample class_indexes.json file mapping each class to its index ("1", "2", …, "13" are the class folder names; 0, 1, …, 12 are the class indices):

    {
        "1": 0,
        "2": 1,
        "3": 2,
        "4": 3,
        "5": 4,
        "6": 5,
        "7": 6,
        "8": 7,
        "9": 8,
        "10": 9,
        "11": 10,
        "12": 11,
        "13": 12
    }
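For the four classes in the original question, the same kind of file would map A, AB, AC and ABC to 0, 1, 2 and 3. A small sketch, using only the standard library, that writes such a file and reads it back into the (classes, class_to_idx) pair that find_classes is expected to return; write_class_indexes and read_class_indexes are names made up for this example:

```python
import json
import os
import tempfile
from typing import Dict, List, Tuple


def write_class_indexes(path: str, classes: List[str]) -> None:
    """Save a class-name -> index mapping, numbering the classes in list order."""
    with open(path, "w") as f:
        json.dump({name: i for i, name in enumerate(classes)}, f, indent=4)


def read_class_indexes(path: str) -> Tuple[List[str], Dict[str, int]]:
    """Load the mapping and return (classes, class_to_idx), with classes ordered by index."""
    with open(path) as f:
        class_to_idx = json.load(f)
    classes = sorted(class_to_idx, key=class_to_idx.get)
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "class_indexes.json")
    write_class_indexes(path, ["A", "AB", "AC", "ABC"])
    classes, class_to_idx = read_class_indexes(path)

print(classes)       # ['A', 'AB', 'AC', 'ABC']
print(class_to_idx)  # {'A': 0, 'AB': 1, 'AC': 2, 'ABC': 3}
```

Whichever way the file is produced, its keys should match the class folder names exactly, and the same CustomImageFolder should be used for both the train and valid datasets so that they share one mapping; the DataLoader code itself can stay as it is.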