I am using a custom dataset whose images have dimensions (3, 64, 64), with 28 classes. The code works fine with a pre-trained ResNet50 on CIFAR-10, and on a smaller 4-class dataset with similar image dimensions. With the 28-class dataset, however, it fails after a few cycles. The error below is raised for the 'ct' dataset inside the data loader's `__getitem__(self, index)`:
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision.datasets import CIFAR10, ImageFolder
from config import *
class MyDataset(Dataset):
    """Dataset wrapper whose items also carry the sample's index.

    The wrapped dataset is chosen by ``dataset_name``:

    * ``"cifar10"`` -- torchvision ``CIFAR10`` rooted at ``../cifar10``
      (``download=True``), split selected by ``train_flag``.
    * ``"ct"`` -- ``ImageFolder`` over the CT ``train`` directory.
    * ``"satellite"`` -- ``ImageFolder`` over the satellite ``train`` directory.

    ``train_flag`` is only forwarded to ``CIFAR10``; the two ``ImageFolder``
    datasets always read their ``train`` directory.

    The wrapped dataset is stored on an attribute named after
    ``dataset_name`` (``self.cifar10``, ``self.ct`` or ``self.satellite``).

    Args:
        dataset_name: One of ``"cifar10"``, ``"ct"``, ``"satellite"``.
        train_flag: Passed as ``train=`` to ``CIFAR10``.
        transf: Transform handed to the wrapped dataset.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    _SUPPORTED = ("cifar10", "ct", "satellite")

    def __init__(self, dataset_name, train_flag, transf):
        super().__init__()
        self.dataset_name = dataset_name
        if dataset_name == "cifar10":
            self.cifar10 = CIFAR10('../cifar10', train=train_flag,
                                   download=True, transform=transf)
        elif dataset_name == "ct":
            self.ct = ImageFolder(root='/Dataset/radiology_ai/CT/Split-CT-abd/train',
                                  transform=transf)
        elif dataset_name == "satellite":
            self.satellite = ImageFolder(root='/Dataset/Satellite/train',
                                         transform=transf)
        else:
            # Fail fast instead of erroring later in __getitem__/__len__.
            raise ValueError(
                f"Unsupported dataset_name: {dataset_name!r}; "
                f"expected one of {self._SUPPORTED}"
            )

    def _base(self):
        """Return the wrapped dataset selected by ``self.dataset_name``."""
        if self.dataset_name not in self._SUPPORTED:
            raise ValueError(f"Unsupported dataset_name: {self.dataset_name!r}")
        # One lookup shared by __getitem__ and __len__, so the two can never
        # disagree about which dataset they read (the satellite length used
        # to be taken from self.ct).
        return getattr(self, self.dataset_name)

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the sample at ``index``."""
        data, target = self._base()[index]
        # The index is returned alongside the sample, presumably so callers can
        # map batch items back to dataset positions -- confirm against caller.
        return data, target, index

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self._base())
def load_dataset(dataset):
    """Build the datasets and run settings for one of the supported datasets.

    Args:
        dataset: Dataset name: ``'cifar10'``, ``'ct'`` or ``'satellite'``.

    Returns:
        Tuple ``(data_train, data_unlabeled, data_test, adden, NO_CLASSES,
        no_train)``:

        * ``data_train`` -- training split with random flip + crop augmentation.
        * ``data_unlabeled`` -- the training images wrapped in ``MyDataset``
          with the deterministic test transform; items are
          ``(data, target, index)``.
        * ``data_test`` -- evaluation split (the ``val`` directory for ``ct``,
          ``test`` for ``satellite``, the CIFAR-10 test set for ``cifar10``).
        * ``adden`` -- ``ADDENDUM`` from ``config``.
        * ``NO_CLASSES`` -- number of classes (10, 28 or 4).
        * ``no_train`` -- ``NUM_TRAIN`` from ``config``, capped at
          ``len(data_train)``.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Random-crop size per dataset. CIFAR-10 images are 32x32, and the
    # ImageFolder datasets are taken to be 64x64 (the original crop size).
    # RandomCrop raises when the crop exceeds the padded image, so a single
    # 64-pixel crop cannot serve CIFAR-10.
    crop_sizes = {'cifar10': 32, 'ct': 64, 'satellite': 64}
    if dataset not in crop_sizes:
        raise ValueError(
            f"Unsupported dataset: {dataset!r}; expected one of {sorted(crop_sizes)}"
        )

    # CIFAR-10 per-channel mean/std, applied to every dataset here.
    # (CIFAR-100 alternative: (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761).)
    # NOTE(review): ct/satellite may benefit from their own statistics.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    train_transform = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(size=crop_sizes[dataset], padding=4),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    if dataset == 'cifar10':
        data_train = CIFAR10('../cifar10', train=True, download=True, transform=train_transform)
        data_test = CIFAR10('../cifar10', train=False, download=True, transform=test_transform)
        NO_CLASSES = 10
    elif dataset == 'ct':
        data_train = ImageFolder(root='/Dataset/radiology_ai/CT/Split-CT-abd/train',
                                 transform=train_transform)
        data_test = ImageFolder(root='/Dataset/radiology_ai/CT/Split-CT-abd/val',
                                transform=test_transform)
        # NOTE(review): hard-coded; should equal len(data_train.classes), and the
        # val folder should expose the same class subfolders -- confirm.
        NO_CLASSES = 28
    else:  # 'satellite'
        data_train = ImageFolder(root='/Dataset/Satellite/train', transform=train_transform)
        data_test = ImageFolder(root='/Dataset/Satellite/test', transform=test_transform)
        NO_CLASSES = 4

    data_unlabeled = MyDataset(dataset, True, test_transform)
    adden = ADDENDUM
    # NUM_TRAIN is one config value shared by all datasets. If it exceeds the
    # number of images actually present, indices drawn from range(no_train)
    # run past the end of the dataset and __getitem__ raises IndexError.
    # Cap it at the real size.
    no_train = min(NUM_TRAIN, len(data_train))
    return data_train, data_unlabeled, data_test, adden, NO_CLASSES, no_train