Here is my code:
class PlantDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of file names and labels.

    Each item is read from disk with OpenCV, converted from BGR to RGB, and
    optionally passed through an albumentations-style transform.

    Args:
        df: DataFrame with an ``'image'`` column (file name relative to the
            image directory) and a ``'labels'`` column.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a dict holding the transformed image under ``'image'``
            (the albumentations ``Compose`` convention). When ``None`` the
            RGB array produced by OpenCV is returned unchanged.
        img_dir: Directory containing the image files. When ``None`` (the
            default) the module-level ``train_img_path`` is used, so existing
            callers keep their current behavior.
    """

    def __init__(self, df, transform=None, img_dir=None):
        self.images = df['image'].values
        # NOTE(review): labels are returned as stored in the DataFrame; if they
        # are strings they must be encoded before being fed to a loss — confirm.
        self.labels = df['labels'].values
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_id = self.images[idx]
        label = self.labels[idx]
        # Resolve the global lazily so the class can be defined before it exists.
        img_dir = train_img_path if self.img_dir is None else self.img_dir
        image_path = os.path.join(img_dir, image_id)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising; fail here with a clear message rather than
            # an opaque error from cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # albumentations pipelines return a dict of targets; unpack the
            # image so the DataLoader collates tensors instead of dicts.
            image = self.transform(image=image)['image']
        return image, label
def get_transform(phase: str, *, size: int = 224,
                  mean=(0.4916, 0.4498, 0.400),
                  std=(0.2474, 0.2362, 0.2322)):
    """Build the albumentations preprocessing pipeline for a dataset split.

    Args:
        phase: ``'train'`` inserts random augmentations (flip, shift/scale/
            rotate, brightness/contrast); any other value yields the
            deterministic evaluation pipeline.
        size: Height and width images are resized to (square).
        mean: Per-channel mean passed to ``A.Normalize``.
        std: Per-channel std passed to ``A.Normalize``.

    Returns:
        An ``A.Compose`` pipeline: Resize -> [augmentations] -> Normalize ->
        ToTensorV2.
    """
    # NOTE(review): the default mean/std differ from ImageNet statistics and
    # were presumably computed on this dataset's training images — confirm.
    augmentations = []
    if phase == 'train':
        augmentations = [
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
        ]
    # Resize, Normalize and ToTensorV2 are shared by every phase; only the
    # random augmentations in between depend on `phase`.
    return A.Compose([
        A.Resize(size, size),
        *augmentations,
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
BATCH_SIZE = 64

train = PlantDataset(df_train, transform=get_transform('train'))
valid = PlantDataset(df_valid, transform=get_transform('valid'))

# Shuffle only the training split; validation order stays fixed so metrics
# are computed over the same sequence every epoch.
trainloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
validloader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE)

# Sanity check: pull one training batch to confirm the pipeline runs end to end.
images, labels = next(iter(trainloader))