Custom Collate Fn not called

I want the output of my loader to be a list of lists, where each inner list holds the frames of one video, so that I can traverse them in a loop during inference.
But my collate function does not seem to be invoked: I don't see 'x' getting printed.

# Pull only the first batch from the test loader and report how many
# top-level entries it contains.
for batch in vidloader_test:
    print(len(batch))
    break
def my_collate(batch):
    """Collate samples without stacking the per-sample data.

    Args:
        batch: Sequence of samples; element 0 of each sample is its data and
            element 1 is its label.

    Returns:
        A two-element list ``[data, target]`` where ``data`` is a plain list
        with one entry per sample and ``target`` is a ``torch.LongTensor``
        built from the labels.
    """
    per_sample_data = []
    labels = []
    for sample in batch:
        per_sample_data.append(sample[0])
        labels.append(sample[1])
    label_tensor = torch.LongTensor(labels)
    # Debug marker showing that this collate function was actually invoked.
    print('x')
    return [per_sample_data, label_tensor]
# One dataset per split.
# NOTE(review): the validation dataset is built from Path_train, the same
# path as the training dataset — confirm this is intended.
vidset_test = vidSet(Path_test)
vidset_valid = vidSet(Path_train)
vidset_train = vidSet(Path_train)

# Only the test loader is given the custom collate function; the other two
# loaders are created without a collate_fn argument.
vidloader_test = torch.utils.data.DataLoader(
    vidset_test,
    batch_size=3,
    shuffle=True,
    collate_fn=my_collate,
)
vidloader_valid = DataLoader(vidset_valid, batch_size=64, shuffle=False)
vidloader_train = DataLoader(vidset_train, batch_size=64, shuffle=False)
class vidSet(Dataset):
    """Dataset yielding, for each video, the frame list read from that video.

    Every item is a ``(faces_list, 0)`` pair: ``faces_list`` is the value
    returned by ``read_video`` for the video (or an empty list when it
    returns ``None``) and the label is always the constant ``0``.
    """

    def __init__(self, videos_path):
        """Index the videos listed by *videos_path*.

        Args:
            videos_path: Object exposing an ``ls()`` method; the result of
                that call is stored as the collection of videos to read.
        """
        super().__init__()
        self.video_paths = videos_path.ls()
        self.root = videos_path
        self.c = 2  # presumably the number of classes — TODO confirm

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Read the video at position *idx*.

        Args:
            idx: Index into ``self.video_paths``.

        Returns:
            A ``(faces_list, 0)`` tuple. ``faces_list`` is an empty list when
            ``read_video`` returned ``None``.
        """
        faces_list = read_video(mtcnn, path=self.video_paths[idx])

        if faces_list is None:
            print('None')
            # Previously execution continued with None and crashed with a
            # TypeError on len() below. Substitute an empty frame list so a
            # single unreadable video does not abort the whole iteration.
            faces_list = []

        print('sampleid', idx, 'length of frame list returned ', len(faces_list))
        return faces_list, 0
sampleid 269 length of frame list returned  2
sampleid 222 length of frame list returned  3
sampleid 236 length of frame list returned  2

Facing the same issue, any update on this?

Can you post a minimal example so we can reproduce this issue?

My code follows exactly the same logic as above: I’m iterating through a dataloader that uses a custom collate_fn, and I have a print statement inside it that I expect to be called, but I don’t see that happening.

I can’t reproduce this issue with a similar code snippet on my end:

import torch

class MyDataset(torch.utils.data.Dataset):
    """Toy dataset of 1000 samples derived purely from the index."""

    def __init__(self):
        # Nothing to set up: every sample is computed from its index.
        pass

    def __len__(self):
        """Return the fixed dataset size."""
        return 1000

    def __getitem__(self, idx):
        """Return ``(idx as a float tensor, idx modulo 2)``."""
        value = torch.tensor(idx, dtype=torch.float)
        parity = idx % 2
        return (value, parity)

def mycollate(batch):
    """Split a batch into a list of data items and a list of boolean labels.

    Args:
        batch: Sequence of samples; element 0 is the data and element 1 is a
            value compared against zero to form the label.

    Returns:
        A ``(data, label)`` tuple of two lists, one entry per sample. A label
        is ``True`` when the sample's second element is greater than zero.
    """
    # Debug marker showing that the collate function was invoked.
    print("hello")
    data = []
    label = []
    for sample in batch:
        data.append(sample[0])
    for sample in batch:
        label.append(bool(sample[1] > 0))
    return data, label

dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=mycollate,
)
# Inspect only the first batch: its length, then its full contents.
for batch in dataloader:
    print(len(batch))
    print(batch)
    break

This outputs:

hello
2
([tensor(532.), tensor(948.), tensor(348.), tensor(815.), tensor(564.), tensor(455.), tensor(610.), tensor(888.), tensor(163.), tensor(728.), tensor(844.), tensor(462.), tensor(106.), tensor(982.), tensor(156.), tensor(814.), tensor(822.), tensor(971.), tensor(461.), tensor(969.), tensor(750.), tensor(845.), tensor(147.), tensor(112.), tensor(380.), tensor(384.), tensor(912.), tensor(770.), tensor(761.), tensor(406.), tensor(643.), tensor(856.), tensor(186.), tensor(577.), tensor(339.), tensor(360.), tensor(335.), tensor(408.), tensor(749.), tensor(485.), tensor(68.), tensor(159.), tensor(509.), tensor(932.), tensor(904.), tensor(891.), tensor(832.), tensor(946.), tensor(790.), tensor(847.), tensor(981.), tensor(923.), tensor(554.), tensor(288.), tensor(215.), tensor(879.), tensor(262.), tensor(361.), tensor(472.), tensor(908.), tensor(138.), tensor(531.), tensor(840.), tensor(772.)], [False, False, False, True, False, True, False, False, True, False, False, False, False, False, False, False, False, True, True, True, False, True, True, False, False, False, False, False, True, False, True, False, False, True, True, False, True, False, True, True, False, True, True, False, False, True, False, False, False, True, True, True, False, False, True, True, False, True, False, False, False, True, False, False])

Haha actually on my end, I was using an erroneous assert statement, so it’s actually working for me too :slight_smile: I really appreciate your help though!