Loading data in list-of-lists form

Hello!

I want my dataloader to return a list of length 3 in which each element is itself a list, in the following form:

[[tensor1, tensor2], [tensor3, tensor4], [tensor5, tensor6, tensor7]]
tensor1.shape = tensor3.shape
tensor2.shape = tensor4.shape
tensor1.shape != tensor2.shape
tensor5.shape != tensor6.shape != tensor7.shape

As you can see, it’s a somewhat intricate format, but I need it this way. I am using PyTorch Lightning, and in the .fit function I receive the following error: "each element in list of batch should be of equal size". If I use collate_fn=lambda x: x the error doesn’t appear, but the format of my batch changes: I receive a list of all the samples in the batch (len(batch) = batch_size).
What I want instead is for each tensor in the list shown above to have dimension batch_size x d.

How can I do it? Can this not be done with PyTorch?

Thanks a lot!

Your approach seems to work for me:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        pass
    
    def __len__(self):
        return 10
    
    def __getitem__(self, index):
        tensor1 = torch.randn(2)
        tensor2 = torch.randn(3)
        tensor3 = torch.randn(2)
        tensor4 = torch.randn(3)
        tensor5 = torch.randn(4)
        tensor6 = torch.randn(5)
        tensor7 = torch.randn(6)
        return [[tensor1, tensor2], [tensor3, tensor4], [tensor5, tensor6, tensor7]]
    
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5)

for a, b, c in loader:
    print(a[0].shape) # torch.Size([5, 2])
    print(a[1].shape) # torch.Size([5, 3])
    print(b[0].shape) # torch.Size([5, 2])
    print(b[1].shape) # torch.Size([5, 3])
    print(c[0].shape) # torch.Size([5, 4])
    print(c[1].shape) # torch.Size([5, 5])
    print(c[2].shape) # torch.Size([5, 6])
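
If the default collation does not give this layout in some setup, one alternative is to do the regrouping explicitly. Below is a minimal sketch, not the way PyTorch does it internally. The name collate_nested and its stack parameter are hypothetical. The function is written in plain Python and takes the leaf-stacking function as an argument (for tensors that would be torch.stack), so the regrouping logic can be tried without any tensors at all.

```python
def collate_nested(batch, stack):
    """Regroup a batch of identically structured nested lists.

    batch: list of samples, e.g. [[t1, t2], [t3, t4], [t5, t6, t7]] per sample.
    stack: callable that merges a list of leaves into one batched leaf
           (hypothetical parameter; for tensors, torch.stack).
    """
    first = batch[0]
    if isinstance(first, (list, tuple)):
        # Every sample must have the same number of entries at this level.
        if any(len(sample) != len(first) for sample in batch):
            raise ValueError("samples have different lengths at this nesting level")
        # zip(*batch) groups the i-th entry of every sample together.
        return [collate_nested(list(group), stack) for group in zip(*batch)]
    # Leaves: merge the per-sample values into one batched value.
    return stack(batch)
```

With tensors this could be passed to the loader as DataLoader(dataset, batch_size=5, collate_fn=functools.partial(collate_nested, stack=torch.stack)), which should yield the same [5, d] shapes as printed above, assuming every sample has the same structure.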

Hi, thanks for answering!

It doesn’t work when I use PyTorch Lightning; I still receive the error mentioned above. Sorry, I should have specified that. I assumed it would happen regardless of the framework.
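
One thing that may be worth checking: as far as I know, PyTorch's default collate function raises this message when the lists at the same position have different lengths across the samples of one batch. So a possible cause is that some samples do not have exactly the structure shown in the question. Below is a small diagnostic sketch, assuming the dataset supports len() and integer indexing. The helper names structure_of and find_mismatched_samples are hypothetical, not part of PyTorch or Lightning.

```python
def structure_of(item):
    """Return a nested tuple describing list lengths and leaf shapes."""
    if isinstance(item, (list, tuple)):
        return tuple(structure_of(x) for x in item)
    # Leaves: record the shape if the object has one (tensors, arrays),
    # otherwise fall back to the type name.
    shape = getattr(item, "shape", None)
    return tuple(shape) if shape is not None else type(item).__name__


def find_mismatched_samples(dataset):
    """Return the indices whose structure differs from that of sample 0."""
    reference = structure_of(dataset[0])
    return [i for i in range(len(dataset)) if structure_of(dataset[i]) != reference]
```

Running find_mismatched_samples on each dataset passed to the trainer would list any samples with a different layout. Since Lightning by default also runs a few validation batches at the start of .fit, the failing batch might come from the validation dataloader rather than the training one.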