Custom collate function

Hi all,
I have a particular use case for which I would like to use a custom collate function.

I have a target (a float) and an input to feed to the model.
For each sample, the input is a dictionary containing 2 images (IMG1, IMG2).

I want to remove the samples where at least one of the two images is not available (None).
I wrote the function below, but it doesn't work:

def custom_collate(original_batch):

    filtered_data = []
    filtered_target = []

    for item in original_batch:
        none_found = False
        if "IMG1" in item[0].keys():
            if item[0]["IMG1"] is None:
                none_found = True
        if "IMG2" in item[0].keys():
            if item[0]["IMG2"] is None:
                none_found = True

        if not none_found:
            filtered_data.append(item[0])
            filtered_target.append(item[1])

    return filtered_data, filtered_target

Do you have any suggestions?

Thanks!

Your code seems to work fine and is filtering out the None samples:

import torch


def custom_collate(original_batch):
    filtered_data = []
    filtered_target = []

    for item in original_batch:
        none_found = False
        if "IMG1" in item[0].keys():
            if item[0]["IMG1"] is None:
                print('none found in IMG1')
                none_found = True
        if "IMG2" in item[0].keys():
            if item[0]["IMG2"] is None:
                none_found = True
                print('none found in IMG2')

        if not none_found:
            filtered_data.append(item[0])
            filtered_target.append(item[1])

    return filtered_data, filtered_target


class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.len = 10
        
    def __getitem__(self, index):
        batch = {
            'IMG1': torch.ones(1, 2, 2) * index if torch.randint(0, 2, (1,)) == 0 else None,
            'IMG2': torch.ones(1, 2, 2) * index + 0.1 if torch.randint(0, 2, (1,)) == 0 else None
        }
        target = torch.randn(1,)
        print(batch)
        return batch, target
    
    def __len__(self):
        return self.len
    
dataset = MyDataset()
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=custom_collate)

for a, b in loader:
    print('='*10)
    print(a)
    print(b)

Output:

{'IMG1': tensor([[[0., 0.],
         [0., 0.]]]), 'IMG2': tensor([[[0.1000, 0.1000],
         [0.1000, 0.1000]]])}
{'IMG1': tensor([[[1., 1.],
         [1., 1.]]]), 'IMG2': None}
none found in IMG2
==========
[{'IMG1': tensor([[[0., 0.],
         [0., 0.]]]), 'IMG2': tensor([[[0.1000, 0.1000],
         [0.1000, 0.1000]]])}]
[tensor([1.5672])]
{'IMG1': tensor([[[2., 2.],
         [2., 2.]]]), 'IMG2': None}
{'IMG1': tensor([[[3., 3.],
         [3., 3.]]]), 'IMG2': None}
none found in IMG2
none found in IMG2
==========
[]
[]
{'IMG1': tensor([[[4., 4.],
         [4., 4.]]]), 'IMG2': None}
{'IMG1': None, 'IMG2': None}
none found in IMG2
none found in IMG1
none found in IMG2
==========
[]
[]
{'IMG1': None, 'IMG2': tensor([[[6.1000, 6.1000],
         [6.1000, 6.1000]]])}
{'IMG1': None, 'IMG2': None}
none found in IMG1
none found in IMG1
none found in IMG2
==========
[]
[]
{'IMG1': tensor([[[8., 8.],
         [8., 8.]]]), 'IMG2': None}
{'IMG1': None, 'IMG2': None}
none found in IMG2
none found in IMG1
none found in IMG2
==========
[]
[]

Thanks for answering @ptrblck !
The problem is that if I use this function, I get the error:

TypeError: list indices must be integers or slices, not str
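The message says that a Python list was indexed with a string key. A minimal sketch of one way this can happen, assuming the model (or training step) does x["IMG1"] while the collate function hands it a plain list of sample dictionaries (the tensors below are made up for illustration):

```python
import torch

# Hypothetical batch as returned by a collate function that only filters:
# a plain list of per-sample dicts instead of one dict of batched tensors.
x = [{"IMG1": torch.zeros(1, 4, 4, 4)}, {"IMG1": torch.ones(1, 4, 4, 4)}]

message = None
try:
    x["IMG1"]  # code written for a dict input, applied to a list
except TypeError as err:
    message = str(err)

print(message)  # list indices must be integers or slices, not str
```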

If I do not use that function, everything runs smoothly.

I am using PyTorch Lightning.

Thanks a lot.

Are you seeing this error using my code snippet?
If so, could you post the PyTorch version you are using?
If not, could you post a minimal, executable code snippet to reproduce this error?

Your toy example runs without problems.
Unfortunately, creating a snippet is really difficult, but I found the cause of the error (although I don't know how to fix it).

If I do not set the custom collate function, the model is fed a dictionary (in this case with a single key, "IMG1").
x["IMG1"].shape returns [12, 1, 40, 40, 40] (batch size 12, 1 channel, 40x40x40 3D volume).
The model trains and everything is fine.
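For comparison, this is roughly what the default batching does with samples of the form (dict, target): each dictionary entry and the targets are stacked along a new leading batch dimension. A small sketch, assuming PyTorch 1.11+ (where default_collate is exported from torch.utils.data) and using a made-up 4x4x4 volume in place of the real 40x40x40 one:

```python
import torch
from torch.utils.data import default_collate

# Hypothetical samples: (dict with one image, float target), 12 of them.
samples = [({"IMG1": torch.zeros(1, 4, 4, 4)}, torch.tensor(0.5)) for _ in range(12)]

# default_collate transposes the list of (dict, target) pairs and stacks each field.
data, target = default_collate(samples)
print(type(data), data["IMG1"].shape)  # <class 'dict'> torch.Size([12, 1, 4, 4, 4])
print(target.shape)                    # torch.Size([12])
```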

If I set the function I reported in the first message, the model gets a quite different input.
In this case, I get a list of 12 dictionaries.
Each dictionary has the key "IMG1", which contains a tensor of shape [1, 40, 40, 40] (I get this by running x[0]["IMG1"].shape).
In other words, I get a list of length batch_size, where each element is a single sample's dictionary and its tensor has no batch dimension.
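The structure described above can be reproduced with the filtering logic from the first post, written a bit more compactly here. A sketch with made-up 4x4x4 tensors; the filter itself is unchanged (a sample is dropped only if a key is present and its value is None):

```python
import torch

def custom_collate(original_batch):
    # Same rule as the first post: skip samples whose IMG1 or IMG2 entry is None.
    filtered_data, filtered_target = [], []
    for data, target in original_batch:
        if any(key in data and data[key] is None for key in ("IMG1", "IMG2")):
            continue
        filtered_data.append(data)
        filtered_target.append(target)
    # Nothing is stacked: the result is two Python lists of individual samples.
    return filtered_data, filtered_target

samples = [({"IMG1": torch.zeros(1, 4, 4, 4)}, torch.tensor(0.5)) for _ in range(12)]
x, y = custom_collate(samples)
print(type(x), len(x))     # <class 'list'> 12
print(x[0]["IMG1"].shape)  # torch.Size([1, 4, 4, 4]), no batch dimension
```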

I hope I explained my case clearly.

Thanks a lot @ptrblck !

I think I figured it out: the function returns exactly what I asked for (a list of samples), so I have to rewrite it to package the output in the proper way.
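One possible way to package the output, sketched here as an assumption rather than the exact fix used in this thread: filter first, then hand the surviving samples to PyTorch's default_collate (PyTorch 1.11+) so the model receives the same dict-of-batched-tensors structure as without a custom collate function. The name filtering_collate is hypothetical. If every sample in a batch is filtered out, the sketch returns None, and the training step would then have to detect and skip that batch (how Lightning treats a None batch is not verified here).

```python
import torch
from torch.utils.data import default_collate

IMAGE_KEYS = ("IMG1", "IMG2")  # keys checked in the original post

def filtering_collate(original_batch):
    """Drop samples with a missing (None) image, then batch the rest as usual."""
    kept = [
        (data, target)
        for data, target in original_batch
        # Same rule as the original: drop a sample if a key is present but None.
        if not any(key in data and data[key] is None for key in IMAGE_KEYS)
    ]
    if not kept:
        # default_collate cannot batch an empty list; the caller must skip None.
        return None
    # Stack each dict entry and the targets along a new batch dimension.
    return default_collate(kept)

samples = [
    ({"IMG1": torch.zeros(1, 4, 4, 4), "IMG2": torch.ones(1, 4, 4, 4)}, torch.tensor(0.1)),
    ({"IMG1": torch.zeros(1, 4, 4, 4), "IMG2": None}, torch.tensor(0.2)),
    ({"IMG1": torch.zeros(1, 4, 4, 4), "IMG2": torch.ones(1, 4, 4, 4)}, torch.tensor(0.3)),
]
x, y = filtering_collate(samples)
print(x["IMG1"].shape, y.shape)  # torch.Size([2, 1, 4, 4, 4]) torch.Size([2])
```

It would be passed to the loader as collate_fn=filtering_collate. Delegating to default_collate keeps the batching rules identical to the no-custom-collate case, so only the filtering step is new code.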

I have rewritten the function the right way, @ptrblck.

I think it is a little strange that the function receives one structure as input (a list of individual samples) and must return a different one (a batch). I expected to return the data in the same structure as the input.
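A small sketch of that input/output contract: collate_fn receives a list with one entry per sample, each exactly what the dataset's __getitem__ returned, and the training loop receives whatever collate_fn returns, unchanged. ToyDataset and identity_collate are hypothetical names for illustration; the sketch assumes the default single-process loader (num_workers=0).

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Hypothetical dataset: each item is (dict of images, target)."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        return {"IMG1": torch.full((1, 2, 2), float(index))}, torch.tensor(float(index))

def identity_collate(batch):
    # batch is a list of batch_size (dict, target) tuples; return it untouched.
    return batch

loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=identity_collate)
first = next(iter(loader))
print(type(first), len(first), type(first[0]))  # <class 'list'> 2 <class 'tuple'>
```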

The documentation does not point out this aspect.

Thanks for helping in debugging!

Paolo