I want to classify cat and dog images.
So I downloaded images from Kaggle, and divided into train and validation set.
here is my code…
class TrainImageFolder(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.label = [id_ for id_ in os.listdir(path)]
self.ids = list()
self.transform = transform
for i in self.label:
for j in os.listdir(os.path.join(path, i)):
self.ids += [i + '_' + j]
def __len__(self):
len(self.ids)
def __getitem__(self, index):
image = Image.open(self.ids[index][4:])
image = image.convert('RGB')
image = transforms.ToTensor()
if (self.ids[index][:3] == 'Cat'):
label = 0
else:
label = 1
if self.transform:
image = self.transform(image)
return image, label
transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
dataset = TrainImageFolder('PetImages/train')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
for i, (img, bbox, label, difficult) in tqdm(enumerate(dataloader)):
continue
and error is…
Traceback (most recent call last):
File "C:/Users/ge971/PycharmProjects/myVGG16/dataset.py", line 49, in <module>
for i, (img, bbox, label, difficult) in tqdm(enumerate(dataloader)):
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\tqdm\std.py", line 1166, in __iter__
for obj in iterable:
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\torch\utils\data\dataloader.py", line 435, in __next__
data = self._next_data()
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\torch\utils\data\dataloader.py", line 474, in _next_data
index = self._next_index() # may raise StopIteration
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\torch\utils\data\dataloader.py", line 427, in _next_index
return next(self._sampler_iter) # may raise StopIteration
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\torch\utils\data\sampler.py", line 227, in __iter__
for idx in self.sampler:
File "C:\Users\ge971\miniconda3\envs\torch17\lib\site-packages\torch\utils\data\sampler.py", line 67, in __iter__
return iter(range(len(self.data_source)))
TypeError: 'NoneType' object cannot be interpreted as an integer
Process finished with exit code 1
How should I fix my code? Please help me…