Your code seems to work fine and filters out the `None`
samples:
def custom_collate(original_batch, keys=("IMG1", "IMG2")):
    """Drop samples that contain a ``None`` image from a DataLoader batch.

    Intended for use as ``DataLoader(..., collate_fn=custom_collate)``.

    Args:
        original_batch: list of items as returned by the dataset's
            ``__getitem__``; ``item[0]`` is the sample dict and ``item[1]``
            its target.
        keys: sample-dict keys to inspect. A sample is dropped if any of these
            keys is present and maps to ``None``; keys that are absent from a
            sample are ignored. Defaults to ``("IMG1", "IMG2")``.

    Returns:
        ``(filtered_data, filtered_target)``: two lists of equal length with
        the surviving samples and their targets, in original order. They are
        not stacked into tensors, and both are empty if every sample in the
        batch was dropped.
    """
    filtered_data = []
    filtered_target = []
    for item in original_batch:
        sample, target = item[0], item[1]
        none_found = False
        # No early break: every None entry in the sample gets reported.
        for key in keys:
            if key in sample and sample[key] is None:
                print(f'none found in {key}')
                none_found = True
        if not none_found:
            filtered_data.append(sample)
            filtered_target.append(target)
    return filtered_data, filtered_target
class MyDataset(torch.utils.data.Dataset):
    """Toy dataset of 10 samples whose image entries are randomly ``None``.

    Each item is a ``(sample, target)`` pair. ``sample`` is a dict with keys
    ``'IMG1'`` and ``'IMG2'``. Each entry is independently either a
    ``torch.ones(1, 2, 2)`` tensor scaled by the index (``IMG2`` additionally
    offset by 0.1) or ``None``, decided by a ``torch.randint`` coin flip.
    ``target`` is a ``torch.randn(1,)`` tensor. Every sample is printed as it
    is produced.
    """

    def __init__(self):
        self.len = 10

    def __getitem__(self, index):
        # Coin flips are drawn IMG1 first, then IMG2, then the target.
        keep_img1 = torch.randint(0, 2, (1,)) == 0
        img1 = torch.ones(1, 2, 2) * index if keep_img1 else None
        keep_img2 = torch.randint(0, 2, (1,)) == 0
        img2 = torch.ones(1, 2, 2) * index + 0.1 if keep_img2 else None
        sample = {'IMG1': img1, 'IMG2': img2}
        target = torch.randn(1,)
        print(sample)
        return sample, target

    def __len__(self):
        return self.len
# Demo: load the toy dataset two samples at a time through custom_collate and
# print a separator, the surviving samples, and their targets for each batch.
dataset = MyDataset()
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    collate_fn=custom_collate,
)
for a, b in loader:
    print('=' * 10, a, b, sep='\n')
Output:
{'IMG1': tensor([[[0., 0.],
[0., 0.]]]), 'IMG2': tensor([[[0.1000, 0.1000],
[0.1000, 0.1000]]])}
{'IMG1': tensor([[[1., 1.],
[1., 1.]]]), 'IMG2': None}
none found in IMG2
==========
[{'IMG1': tensor([[[0., 0.],
[0., 0.]]]), 'IMG2': tensor([[[0.1000, 0.1000],
[0.1000, 0.1000]]])}]
[tensor([1.5672])]
{'IMG1': tensor([[[2., 2.],
[2., 2.]]]), 'IMG2': None}
{'IMG1': tensor([[[3., 3.],
[3., 3.]]]), 'IMG2': None}
none found in IMG2
none found in IMG2
==========
[]
[]
{'IMG1': tensor([[[4., 4.],
[4., 4.]]]), 'IMG2': None}
{'IMG1': None, 'IMG2': None}
none found in IMG2
none found in IMG1
none found in IMG2
==========
[]
[]
{'IMG1': None, 'IMG2': tensor([[[6.1000, 6.1000],
[6.1000, 6.1000]]])}
{'IMG1': None, 'IMG2': None}
none found in IMG1
none found in IMG1
none found in IMG2
==========
[]
[]
{'IMG1': tensor([[[8., 8.],
[8., 8.]]]), 'IMG2': None}
{'IMG1': None, 'IMG2': None}
none found in IMG2
none found in IMG1
none found in IMG2
==========
[]
[]