Goal:
- create a augmented dataset and concatenate it to the original dataset (CIFAR10)
Attempt:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import torch.optim
import augment
from torch.utils.data import ConcatDataset
from torchvision.datasets import CIFAR10
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import dataset
import model_store
# store hyper parameters
hparams = {
"model": "CNN",
"detaset": "CIFAR-10",
"optimizer": "Adam",
"momentum": 0.9,
"epochs": 1,
"train_batch_size": 32,
"eval_batch_size": 32,
"lr": 1e-3,
"checkpoint": 1000,
}
# avalable GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# get dataset
train, test = dataset.get_dataset(hparams["detaset"])
# include augmentations methods
transform_train = augment.get_train_transforms()
transform_test = augment.get_test_transforms()
# Augment the original dataset
augmented_train = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
augmented_test = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_dataset = [train, augmented_train]
test_dataset = [test, augmented_test]
train = ConcatDataset(train_dataset)
test = ConcatDataset(test_dataset)
Error:
AttributeError Traceback (most recent call last)
Cell In[10], line 65
53 train_loader = torch.utils.data.DataLoader(
54 train,
55 batch_size=hparams["train_batch_size"],
56 shuffle=True,
57 num_workers=2)
59 test_loader = torch.utils.data.DataLoader(
60 test,
61 batch_size=hparams["eval_batch_size"],
62 shuffle=False,
63 num_workers=2)
---> 65 hparams["n_classes"] = len(train.classes)
66 hparams["input_shape"] = train[0][0].shape
68 # get model
AttributeError: 'ConcatDataset' object has no attribute 'classes'
Could someone please help me with this?
Thanks
For reference this is augment.py
import torchvision.transforms as transforms
def get_train_transforms():
return transforms.Compose([
transforms.RandomResizedCrop(32, scale=(0.08, 1.0)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
def get_test_transforms():
return transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
and dataset.py
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
DATASETS = [
"CIFAR-10",
]
def get_dataset(dataset:str, dir="./data"):
'''return dataset as a Dataset class of pytorch'''
if dataset not in DATASETS:
raise NotImplementedError("Dataset not found: {}".format(dataset))
if dataset == "CIFAR-10":
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train = datasets.CIFAR10(
root=dir,
train=True,
download=True,
transform=transform)
test = datasets.CIFAR10(
root=dir,
train=False,
download=True,
transform=transform)
return train, test