AttributeError: 'ConcatDataset' object has no attribute 'classes'

Goal:

  • Create an augmented dataset and concatenate it to the original dataset (CIFAR-10).

Attempt:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim
import augment
from torch.utils.data import ConcatDataset
from torchvision.datasets import CIFAR10

import numpy as np
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import dataset
import model_store

# Hyper-parameters for this run.
# NOTE: the key "detaset" is misspelled, but it is kept unchanged because other
# cells read hparams["detaset"].
hparams = {
    "model": "CNN",
    "detaset": "CIFAR-10",
    "optimizer": "Adam",
    "momentum": 0.9,
    "epochs": 1,
    "train_batch_size": 32,
    "eval_batch_size": 32,
    "lr": 1e-3,
    "checkpoint": 1000,
}

# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Original (non-augmented) splits from dataset.py.
train, test = dataset.get_dataset(hparams["detaset"])

# Transform pipelines from augment.py.
transform_train = augment.get_train_transforms()
transform_test = augment.get_test_transforms()

# A second copy of CIFAR-10 whose samples go through the augmentation pipeline.
augmented_train = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
# NOTE(review): get_test_transforms() contains no random operation, so this
# split likely yields the same samples as `test`; concatenating it duplicates
# the evaluation data. Consider evaluating on `test` alone.
augmented_test = CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_dataset = [train, augmented_train]
test_dataset = [test, augmented_test]

train = ConcatDataset(train_dataset)
test = ConcatDataset(test_dataset)

# ConcatDataset only concatenates samples; it does not expose attributes of the
# datasets it wraps, so reading `train.classes` raised AttributeError. Every
# member here is a CIFAR10 dataset with the same label set, so copy the label
# metadata from the first member onto each concatenated dataset.
for concat in (train, test):
    concat.classes = concat.datasets[0].classes
    concat.class_to_idx = concat.datasets[0].class_to_idx

Error:

AttributeError                            Traceback (most recent call last)
Cell In[10], line 65
     53 train_loader = torch.utils.data.DataLoader(
     54                 train,
     55                 batch_size=hparams["train_batch_size"],
     56                 shuffle=True,
     57                 num_workers=2)
     59 test_loader = torch.utils.data.DataLoader(
     60                 test,
     61                 batch_size=hparams["eval_batch_size"],
     62                 shuffle=False,
     63                 num_workers=2)
---> 65 hparams["n_classes"] = len(train.classes)
     66 hparams["input_shape"] = train[0][0].shape
     68 # get model

AttributeError: 'ConcatDataset' object has no attribute 'classes'

Could someone please help me with this?
Thanks

For reference, this is augment.py:

import torchvision.transforms as transforms

def get_train_transforms(size=32, scale=(0.08, 1.0)):
    """Return the training-time augmentation pipeline.

    The returned transform takes a random crop of the input image (its area
    a random fraction of the original drawn from ``scale``), resizes the crop
    to ``size`` x ``size``, converts it to a tensor and normalizes every
    channel with mean 0.5 and std 0.5.

    Args:
        size: Output side length of the crop. Defaults to 32.
        scale: ``(min, max)`` fraction of the original image area to crop.
            Defaults to ``(0.08, 1.0)``.

    NOTE(review): for CIFAR-10's 32x32 images a minimum crop of 8% of the
    area is very aggressive; a milder ``scale`` (or ``RandomCrop(32,
    padding=4)`` plus ``RandomHorizontalFlip()``) may train better.
    """
    return transforms.Compose([
        transforms.RandomResizedCrop(size, scale=scale),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

def get_test_transforms(size=32):
    """Return the deterministic evaluation pipeline.

    Resizes the image to ``size`` (an int resizes the smaller edge), converts
    it to a tensor and normalizes every channel with mean 0.5 and std 0.5,
    i.e. the same normalization as the training pipeline but without any
    random augmentation.

    Args:
        size: Target size passed to ``transforms.Resize``. Defaults to 32.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

and this is dataset.py:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torchvision.transforms as transforms


# Names accepted by get_dataset().
DATASETS = ["CIFAR-10"]

def get_dataset(dataset:str, dir="./data"):
    """Return the ``(train, test)`` splits of *dataset* as PyTorch datasets.

    Args:
        dataset: Dataset name; must be one of ``DATASETS``.
        dir: Root directory the data is read from / downloaded to. (The name
            shadows the ``dir`` builtin but is kept for keyword callers.)

    Returns:
        For ``"CIFAR-10"``: a ``(train, test)`` tuple of
        ``torchvision.datasets.CIFAR10`` instances whose samples are converted
        to tensors and normalized with mean 0.5 / std 0.5 per channel.

    Raises:
        NotImplementedError: if *dataset* is not in ``DATASETS`` or has no
            loader implemented below.
    """
    if dataset not in DATASETS:
        raise NotImplementedError("Dataset not found: {}".format(dataset))

    if dataset == "CIFAR-10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        train = datasets.CIFAR10(
            root=dir,
            train=True,
            download=True,
            transform=transform)

        test = datasets.CIFAR10(
            root=dir,
            train=False,
            download=True,
            transform=transform)

        return train, test

    # Reached only if a name is added to DATASETS without a loader branch
    # above; fail explicitly instead of with an UnboundLocalError.
    raise NotImplementedError("Dataset not implemented: {}".format(dataset))

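The error is expected: `ConcatDataset` only concatenates the samples of the datasets it wraps and does not forward their attributes (such as `classes`). You need to access an inner dataset through the `datasets` list first, then read its attribute: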

import torch
from torchvision import datasets, transforms

# Named `mnist` (not `dataset`) so it does not shadow the question's
# `import dataset` module when run in the same notebook.
mnist = datasets.MNIST(
    root="./data",
    download=False,  # assumes MNIST is already present under ./data
    transform=transforms.ToTensor()
)
print(mnist.classes)
# ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

cat_dataset = torch.utils.data.ConcatDataset([mnist])

# fails: ConcatDataset does not forward attributes of the datasets it wraps.
# Caught here so the rest of the demo still runs.
try:
    cat_dataset.classes
except AttributeError as err:
    print(err)
# 'ConcatDataset' object has no attribute 'classes'

# works: the wrapped datasets are kept, in order, in `cat_dataset.datasets`
print(cat_dataset.datasets[0].classes)
# ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

# Applied to the question's code:
# hparams["n_classes"] = len(train.datasets[0].classes)