Custom collate_fn is very slow compared to running the same code outside the DataLoader

I am working on a multi-label classification problem with varying input sizes. Here is my custom collate_fn. When I run `x, y = next(iter(train_loader))`, it is very slow and keeps running.

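```python
import numpy as np
import torch


def my_collate(batch):
    # Images are (3, S, S) tensors whose side length S varies per sample.
    data = [item[0] for item in batch]
    size = max([i.shape[1] for i in data])

    # Zero-pad every image to the largest side length in the batch.
    data2 = torch.zeros(len(data), 3, size, size)
    for i, j in enumerate(data):
        data2[i, :, 0:j.size()[1], 0:j.size()[1]] = j

    data2 = data2.numpy()  # convert the padded batch to a NumPy array

    # Multi-hot label vectors have a fixed length, so they stack directly.
    target = [item[1] for item in batch]
    target = np.stack(target, 0)
    return data2, target
```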
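The collate_fn above assumes square images, because it uses `shape[1]` for both the height and the width. A sketch of a variant that pads height and width separately and keeps the batch as a `torch.Tensor` instead of converting it to NumPy; `pad_collate` and the toy batch are hypothetical, assuming each sample is an `(image, multi_hot_target)` pair with a `(C, H, W)` image.

```python
import torch


def pad_collate(batch):
    """Zero-pad (C, H, W) images to the largest H and W in the batch.

    Unlike the square-only version, height and width are padded separately,
    and the batch stays a torch.Tensor instead of being converted to NumPy.
    """
    images = [item[0] for item in batch]
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)

    data = torch.zeros(len(images), images[0].shape[0], max_h, max_w)
    for i, img in enumerate(images):
        # Copy each image into the top-left corner; the rest stays zero.
        data[i, :, : img.shape[1], : img.shape[2]] = img

    # Multi-hot target vectors all have the same length, so they stack directly.
    target = torch.stack([item[1] for item in batch], 0)
    return data, target


batch = [
    (torch.ones(3, 32, 40), torch.tensor([1.0, 0.0, 1.0])),
    (torch.ones(3, 48, 24), torch.tensor([0.0, 1.0, 0.0])),
]
x, y = pad_collate(batch)
print(x.shape, y.shape)  # torch.Size([2, 3, 48, 40]) torch.Size([2, 3])
```

It plugs into the loader the same way as before, via `DataLoader(dataset, batch_size=..., collate_fn=pad_collate)`.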

On the other hand, if I modify the collate_fn so that it only collects the samples, and then run the padding code afterwards on the returned batch, it is very fast. How can I solve this issue?

```python
def my_collate(batch):
    # Return the variable-size images as a plain list; no padding here.
    data = [item[0] for item in batch]

    target = [item[1] for item in batch]
    target = np.stack(target, 0)
    return data, target
```

Running the padding code below after I get `data` and `target` from `data, target = next(iter(train_loader))` is much faster:

```python
# Same padding as in the first collate_fn, but run in the main process.
size = max([i.shape[1] for i in data])
data2 = torch.zeros(len(data), 3, size, size)
for i, j in enumerate(data):
    data2[i, :, 0:j.size()[1], 0:j.size()[1]] = j

data2 = data2.numpy()
```
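The padding step can also be written without preallocating a buffer: a sketch that pads each image with `torch.nn.functional.pad` and then calls `torch.stack`. It pads height and width separately, so it does not assume square images; `pad_and_stack` is a hypothetical helper name, and the result stays a `torch.Tensor` rather than a NumPy array.

```python
import torch
import torch.nn.functional as F


def pad_and_stack(images):
    """Zero-pad a list of (C, H, W) tensors to a common H and W, then stack."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    padded = [
        # F.pad's tuple runs from the last dim backwards:
        # (left, right) for W, then (top, bottom) for H. Fill value is 0.
        F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]))
        for img in images
    ]
    return torch.stack(padded, 0)


data = [torch.ones(3, 2, 2), torch.ones(3, 4, 3)]
batch = pad_and_stack(data)
print(batch.shape)  # torch.Size([2, 3, 4, 3])
```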

I figured out the issue: multiprocessing was making it slow. I removed the `num_workers` argument when creating the DataLoader (so it falls back to the default of 0 and loads batches in the main process), and the issue was resolved.
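A minimal sketch for comparing the two settings on your own machine. `RandomImages`, `pad_collate` and `time_first_batch` are hypothetical stand-ins for the real dataset and collate_fn, and the measured times will vary. One likely contributor worth knowing about: with `num_workers > 0`, each call to `iter(train_loader)` starts fresh worker processes (unless `persistent_workers=True` is passed), so timing a single `next(iter(train_loader))` includes that start-up cost.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in dataset: square 3-channel images of varying size."""

    def __init__(self, sizes):
        self.sizes = sizes

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        s = self.sizes[idx]
        return torch.rand(3, s, s), torch.tensor([1.0, 0.0, 1.0])


def pad_collate(batch):
    # Zero-pad square (3, S, S) images to the largest S in the batch.
    size = max(item[0].shape[1] for item in batch)
    data = torch.zeros(len(batch), 3, size, size)
    for i, (img, _) in enumerate(batch):
        data[i, :, : img.shape[1], : img.shape[2]] = img
    target = torch.stack([item[1] for item in batch], 0)
    return data, target


def time_first_batch(num_workers):
    """Seconds to build an iterator and fetch one batch, plus the batch shape."""
    loader = DataLoader(
        RandomImages([32, 48, 40, 64] * 2),
        batch_size=4,
        collate_fn=pad_collate,
        num_workers=num_workers,  # 0 = load in the main process (the default)
    )
    start = time.perf_counter()
    x, _ = next(iter(loader))
    return time.perf_counter() - start, x.shape


if __name__ == "__main__":
    for workers in (0, 2):
        seconds, shape = time_first_batch(workers)
        print(f"num_workers={workers}: {seconds:.3f}s, batch {tuple(shape)}")
```

The `if __name__ == "__main__":` guard matters on platforms that start workers with `spawn` (Windows and macOS by default), and the dataset and collate_fn are defined at module level so they can be pickled for the worker processes.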
