Hi @ptrblck . I fixed the shape mismatch problem, it was due to a bug in the code.
When I’m creating the dataest, I can check that the dataset outputs three images. But when I wrap it in a dataloader with a batch size more than 1, I get an index error only when shuffle=True
.
class MyDataset(Dataset):
def __init__(self, video_path, transform, nb_frames=3):
self.nb_frames = nb_frames
self.video_path = video_path
classes, class_to_idx = self._find_classes(self.video_path)
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
folders = []
for target in sorted(class_to_idx.keys()):
d = os.path.join(video_path,target)
if not os.path.isdir(d):
continue
for folder in sorted(os.listdir(d)):
folders.append([d + "/" + folder,class_to_idx[target]])
# The folders array is of the format [['dir/train/fake/001_003',0],['dir/train/real/001',1], ... etc]
# Now we have to store the individual images in order to feed them into the network
self.samples = []
for f in folders:
fold = sorted(os.listdir(f[0])) #f[0] is the name of the folder, f[1] is the target
idx = f[1]
# Crop data to multiple of nb_frames
fold1 = fold[:-(len(fold)%self.nb_frames)] if len(fold)%self.nb_frames!=0 else fold
# print([fold1,idx])
# break
for file in range(len(fold1)):
fold1[file] = f[0] + "/" + fold1[file]
self.samples.append([fold1,idx])
#self.samples now contain data in the form of - [['../001/1.jpg','../001/2.jpg','001/3.jpg',...],1],[['../001_003/1.jpg',...],0]]
# calculate lengths
self.lens = [len(d[0])//self.nb_frames for d in self.samples]
# calculate offsets
self.offsets = np.concatenate(([0], np.cumsum(self.lens[:-1])))
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args: dir (string): Root directory path.
Returns: tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def pil_loader(self, path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def transform_self(self, image):
# Resize
resize = transforms.Resize(size=(224, 224))
image = resize(image)
# Transform to tensor
image = TF.to_tensor(image)
return image
def __getitem__(self, index):
# subtract offset
print('index: {}'.format(index))
# get corresponding video file
found = False
for i, offset in enumerate(self.offsets):
if index < offset:
#print('subtracting {} from index'.format(self.offsets[i-1]))
index -= self.offsets[i-1]
index *= self.nb_frames
found = True
break
# handle last video separately
if not found:
index -= self.offsets[-1]
index *= self.nb_frames
i += 1
# select correspondind data
print('selecing video {}'.format(i-1))
samples = self.samples[i-1][0]
target = self.samples[i-1][1] # for a particular index, the target remains the same for the three images
#print(target)
# get frames
print('reading frames {}'.format([idx for idx in range(index, index+self.nb_frames)]))
x = []
for idx in range(index, index+self.nb_frames):
path = samples[idx]
tmp = self.pil_loader(path)
if self.transform is not None :
tmp = self.transform(tmp)
# tmp = self.transform_self(tmp)
x.append(tmp)
#print(len(x))
#x = torch.cat(x)
#print(x.shape)
return x[0],x[1],x[2],target
def __len__(self):
return np.sum(self.lens)*self.nb_frames
This is the dataset class.
image_datasets = {x: MyDataset(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val', 'test']}
Here, when I do
for data in image_datasets['train']:
print(data[0].shape,data[1].shape,data[2].shape)
#this gives the perfectly okay result -
index: 0
selecing video 0
reading frames [0, 1, 2]
torch.Size([3, 224, 224]) torch.Size([3, 224, 224]) torch.Size([3, 224, 224])
#and so on
But when I wrap it into the dataloader -
loader = DataLoader(
image_datasets['train'],
batch_size=64,
shuffle=True
)
for *data, target in loader:
print(target)
#this gives the output -
index: 30630
selecing video 589
reading frames [35556, 35557, 35558]
#and then the error -
104 x = []
105 for idx in range(index, index+self.nb_frames):
--> 106 path = samples[idx]
107 tmp = self.pil_loader(path)
108
IndexError: list index out of range
I can’t understand why the dataloader shows the index error only shuffle = True