Yes — it does not deal with your exact issue; I was hoping it might give you some ideas.
Please try out the following collate function:
import torch
import numpy as np
import torch.utils.data as data
def my_collate(batch):
    """Collate a batch of (tensor, label-list) samples.

    Stacks the fixed-size input tensors into a single batch tensor and
    keeps the variable-length label lists as a plain Python list, since
    ragged lists cannot be packed into one tensor.

    Args:
        batch: list of (torch.Tensor, list) pairs from the Dataset.

    Returns:
        [inputs, targets] where inputs has shape (B, 1, *sample_shape)
        and targets is the list of per-sample label lists.
    """
    # Use a local name distinct from the module alias `data`
    # (torch.utils.data) to avoid shadowing it.
    inputs = torch.stack([sample[0].unsqueeze(0) for sample in batch], 0)
    # BUG FIX: torch.Tensor(...) raises on label lists of differing
    # lengths (exactly what this example's dataset yields); keep the
    # labels as a list so the caller can iterate them with len().
    targets = [sample[1] for sample in batch]
    return [inputs, targets]
class dataset(data.Dataset):
    """Toy dataset: each item is a random (5, 6) feature tensor paired
    with a label list whose length equals the item's index."""

    def __init__(self):
        super().__init__()

    def __len__(self):
        # Fixed number of synthetic samples.
        return 100

    def __getitem__(self, index):
        # Random features; labels are 0..index-1, so their length
        # varies per sample (this is what forces a custom collate).
        features = torch.rand(5, 6)
        labels = list(range(index))
        return features, labels
# Demo: iterate the dataset through a DataLoader that uses the custom
# collate function, printing the batch tensor shape and each sample's
# label-list length.
dataloader = data.DataLoader(dataset=dataset(),
                             batch_size=4,
                             shuffle=True,
                             collate_fn=my_collate,  # use custom collate function here
                             pin_memory=True)

for instance in dataloader:
    print(instance[0].shape, len(instance[1]))
    for labels in instance[1]:
        print('length', len(labels))
    # BUG FIX: raw_input() exists only in Python 2 and raises
    # NameError under Python 3 (which the print() calls imply);
    # input() is the Python 3 equivalent. Pauses after each batch.
    input()