I’m trying to solve Cat VS Dog classification problem using pytorch. So I started by creating a DataSet class using the following code:
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset, DataLoader
import cv2, os
class dogVScat(Dataset):
def __init__(self, data_dir, transform):
file_names = os.listdir('train')
self.full_file_names = [os.path.join('train', file_name) for file_name in file_names]
self.labels = []
for file in os.listdir('train'):
if file.split('.')[0] == 'cat':
self.labels.append([1, 0])
self.labels.append([0, 1])
self.labels = torch.tensor(self.labels)
self.transform = transform
def __len__(self):
return len(os.listdir('train'))
def __getitem__(self, idx):
image = cv2.imread(self.full_file_names[idx])
image = self.transform(image)
return image, self.labels[idx]
transformer = transforms.Compose(transforms.ToTensor())
dataset = dogVScat('train', transformer)
image, label = dataset[2]
But when I run the code I face this problem:
Traceback (most recent call last):
File "C:/Users/BHAAK/Desktop/ML_PATH/dirty-hands/dirty-hands file 3/Project.py", line 30, in <module>
image, label = dataset[2]
File "C:\Users\BHAAK\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataset.py", line 25, in __getitem__
raise NotImplementedError
Please if any one can help me fix this problem I will appreciate it…