I have written a custom collate function to pad and stack the data in each batch. When I call collate_my_dataset directly as a plain function on a list of samples, it works and outputs exactly what I expect (it can even stack the entire dataset at once). However, when it is passed to the data loader, the loader complains that it cannot stack the entries. Here is my collate function:
import torch
from torch_geometric.data import Batch

def collate_my_dataset(samples):
    # each sample is a (tokens, label) pair; tokens has shape [seq_len, feat_dim]
    batch_seq_lens = [data[0].size(0) for data in samples]
    max_seq_len = max(batch_seq_lens)
    token_list = [data[0] for data in samples]
    labels = [data[1] for data in samples]
    tokens_padded = []
    for t in token_list:
        if t.shape[0] < max_seq_len:
            # zero-pad shorter sequences up to the longest one in the batch,
            # creating the padding on the same device as the input tensor
            padding = torch.zeros((max_seq_len - t.shape[0], t.shape[1]), device=t.device)
            tokens_padded.append(torch.cat([t, padding], 0))
        else:
            tokens_padded.append(t)
    tokens = torch.stack(tokens_padded, dim=0)  # [batch, max_seq_len, feat_dim]
    return Batch(tokens=tokens, labels=torch.tensor(labels).unsqueeze(-1))
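For reference, this is how I call it directly; a minimal sketch where random tensors stand in for my real samples (the shapes simply mirror the [31, 64] and [36, 64] entries from the error below):

import torch

# two (tokens, label) samples with different sequence lengths and 64 features
samples = [(torch.randn(31, 64), 0), (torch.randn(36, 64), 1)]
batch = collate_my_dataset(samples)
print(batch.tokens.shape)  # torch.Size([2, 36, 64]) -- padded and stacked
print(batch.labels.shape)  # torch.Size([2, 1])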
And here is how I create and call the data loader:
from torch_geometric import loader

batch_size = 8
div_threshold = int(len(tu_dataset) * 0.8)
train_dataset = tu_dataset[:div_threshold]
test_dataset = tu_dataset[div_threshold:]

train_loader = loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
test_loader = loader.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
batch = next(iter(test_loader))
The error is something like:
stack expects each tensor to be equal size, but got [31, 64] at entry 0 and [36, 64] at entry 1
PS1: As mentioned, when the collate function is called directly, its output is exactly what I expect.
PS2: I asked a similar question in another topic, but that one was about PyG, where by its nature I could not use the PyG DataLoader anymore and had to use the native PyTorch DataLoader.
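For completeness, this is the native PyTorch fallback I mean; a minimal sketch, assuming the same train_dataset, batch_size, and collate_my_dataset as above:

from torch.utils.data import DataLoader

# the plain PyTorch DataLoader hands the raw sample list to collate_fn,
# so the custom padding/stacking logic actually runs
native_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_my_dataset)
batch = next(iter(native_loader))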