Custom collate_fn is not called in DataLoader

I have written a collate function to pad and stack the data in a batch. When I call the custom collate function directly, passing it a list of inputs, it works as expected (it can even stack the entire dataset at once). However, when it is passed to the DataLoader, it fails saying it cannot stack the entries. Here is my custom_collate_fn:


import torch
from torch_geometric.data import Batch

def collate_my_dataset(samples):
    # each sample is a (tokens, label) pair
    batch_seq_lens = [data[0].size(0) for data in samples]
    max_seq_len = max(batch_seq_lens)

    token_list = [data[0] for data in samples]
    labels = [data[1] for data in samples]

    # zero-pad every sequence up to the longest one in the batch
    tokens_padded = []
    for t in token_list:
        if t.shape[0] < max_seq_len:
            # create the padding on the same device as the tokens
            padding = torch.zeros((max_seq_len - t.shape[0], t.shape[1]), device=t.device)
            tokens_padded.append(torch.cat([t, padding], 0))
        else:
            tokens_padded.append(t)
    tokens = torch.stack(tokens_padded, dim=0)

    return Batch(tokens=tokens, labels=torch.tensor(labels).unsqueeze(-1))
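
For example, calling it directly on two dummy samples (made-up shapes, just to illustrate) pads and stacks them exactly as I expect:

samples = [
    (torch.randn(31, 64), 0),  # sequence of length 31, feature dim 64
    (torch.randn(36, 64), 1),  # sequence of length 36
]
batch = collate_my_dataset(samples)
print(batch.tokens.shape)  # torch.Size([2, 36, 64])
print(batch.labels.shape)  # torch.Size([2, 1])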

and here is how I create the DataLoader:

from torch_geometric import loader

batch_size = 8

# 80/20 train/test split
div_threshold = int(len(tu_dataset) * 0.8)
train_dataset = tu_dataset[:div_threshold]
test_dataset = tu_dataset[div_threshold:]

train_loader = loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
test_loader = loader.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
batch = next(iter(test_loader))

The error is something like:
stack expects each tensor to be equal size, but got [31, 64] at entry 0 and [36, 64] at entry 1

PS1: As mentioned, the output of the collate function is exactly what I expect when I call it directly.
PS2: I asked a similar question in another topic; that one was about PyG, where by its nature I couldn't use the PyG DataLoader anymore and had to use the native PyTorch DataLoader.

Isn’t this the same issue as described in your previous post?
In your code snippet it seems you are still using the PyG DataLoader, which replaces any collate_fn you pass with its own internal collater:

from torch_geometric import loader
...
train_loader = loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
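
Drop the PyG loader and use the plain PyTorch DataLoader instead; it will honor your collate_fn. A minimal sketch, reusing your variables:

from torch.utils.data import DataLoader

# the native DataLoader calls the collate_fn you pass in
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
batch = next(iter(test_loader))
print(batch.tokens.shape)  # e.g. torch.Size([8, longest_seq_in_batch, 64])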

Yes, I noticed it immediately after posting the issue. I should not have used the PyG DataLoader. Thank you!