Hi everyone, I’m trying to do some classification on the CelebA dataset, but I’m encountering an error when trying to load my data.
The code is:
class MultiClassCelebA(Dataset):
    """Map-style dataset that pairs image files on disk with labels from a DataFrame.

    The DataFrame's index supplies the image file names (relative to
    ``folder_dir``) and its ``labels`` column supplies the per-image labels.

    Args:
        dataframe: Table indexed by image file name, with a ``labels`` column.
        folder_dir: Directory that contains the image files.
        transform: Optional callable applied to the opened image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, dataframe, folder_dir, transform=None, target_transform=None):
        self.dataframe = dataframe
        self.folder_dir = folder_dir
        self.transform = transform
        self.target_transform = target_transform
        # File names come from the index; labels are materialised once as a
        # plain Python list so that __getitem__ is a cheap list lookup.
        self.file_names = dataframe.index
        self.labels = dataframe.labels.values.tolist()

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``.

        The image is opened from ``folder_dir`` joined with the file name at
        ``index``; the label is the first element of that row's ``labels``
        entry. ``transform`` and ``target_transform`` are applied when set.
        """
        image = Image.open(os.path.join(self.folder_dir, self.file_names[index]))
        # Only the first element of the row's label entry is used.
        # NOTE(review): if this element is not a number or tensor (e.g. an
        # object-dtype NumPy array), DataLoader's default collate function is
        # expected to reject the batch -- confirm against the actual data and
        # convert it via ``target_transform`` if so.
        label = self.labels[index][0]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
# Preprocessing shared by both splits: resize to 256x256, then PILToTensor.
tfms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.PILToTensor(),
])

# Training split. Despite the ``_dl`` suffix, ``train_dl`` is the dataset
# object; ``train_dataloader`` is the loader that batches it.
train_dl = MultiClassCelebA(train_df, celeb_path + '/train/', transform=tfms)
train_dataloader = DataLoader(train_dl, batch_size=32, shuffle=True)

# Validation split, configured the same way as the training split.
val_dl = MultiClassCelebA(val_df, celeb_path + '/val/', transform=tfms)
val_dataloader = DataLoader(val_dl, batch_size=32, shuffle=True)

# Fetch a single batch from the training loader.
next(iter(train_dataloader))
But I’m getting the above TypeError from the last line. I’ve checked the output of the dataset, e.g. for index 2:
type(train_dl[2][0])
and it confirms that it is a torch.Tensor. Why is the DataLoader seeing it as an object, and how can I fix this? Thanks