Collate function is not called in dataloader

I have a collate function as below:


```python
import torch


def collate_mydataset(samples):
    print("hello collate function")
    print(samples)
    # Each sample: (node_features, edge_indices, edge_features, graph_label, m)
    num_nodes_list = [data[0].size(0) for data in samples]
    max_num_nodes = max(num_nodes_list)
    num_edges_list = [data[2].size(0) for data in samples]
    max_num_edges = max(num_edges_list)
    features_list = [data[0] for data in samples]       # node features: [num_nodes, F]
    edge_indices_list = [data[1] for data in samples]   # edge indices: [2, num_edges]
    edge_features_list = [data[2] for data in samples]  # edge features: [num_edges, F_e]
    graph_labels_list = [data[3] for data in samples]
    m_list = [data[4] for data in samples]

    # Pad node features with zero rows up to max_num_nodes.
    # new_zeros keeps the dtype and device of the original tensor.
    features_padded = []
    for feature in features_list:
        num_nodes = feature.shape[0]
        if num_nodes < max_num_nodes:
            padding = feature.new_zeros((max_num_nodes - num_nodes, feature.shape[1]))
            features_padded.append(torch.cat([feature, padding], 0))
        else:
            features_padded.append(feature)
    features = torch.stack(features_padded, dim=0)  # [batch, max_num_nodes, F]

    # Pad edge indices with zero columns; new_zeros keeps them integer (long).
    # Note: padded edges point at node 0, so they may need masking downstream.
    edge_indices_padded = []
    for edge_indices in edge_indices_list:
        num_edges = edge_indices.shape[1]
        if num_edges < max_num_edges:
            padding = edge_indices.new_zeros((2, max_num_edges - num_edges))
            edge_indices_padded.append(torch.cat([edge_indices, padding], 1))
        else:
            edge_indices_padded.append(edge_indices)
    edge_indices = torch.stack(edge_indices_padded, dim=1)  # [2, batch, max_num_edges]

    # Pad edge features with zero rows up to max_num_edges.
    edge_features_padded = []
    for e_feature in edge_features_list:
        num_edges = e_feature.shape[0]
        if num_edges < max_num_edges:
            padding = e_feature.new_zeros((max_num_edges - num_edges, e_feature.shape[1]))
            edge_features_padded.append(torch.cat([e_feature, padding], 0))
        else:
            edge_features_padded.append(e_feature)
    edge_features = torch.stack(edge_features_padded, dim=0)  # [batch, max_num_edges, F_e]

    graph_labels = torch.stack(graph_labels_list, dim=0)

    # Pad m along its first (node) dimension, like the node features.
    m_padded = []
    for m in m_list:
        num_nodes = m.shape[0]
        if num_nodes < max_num_nodes:
            padding = m.new_zeros((max_num_nodes - num_nodes, m.shape[1]))
            m_padded.append(torch.cat([m, padding], 0))
        else:
            m_padded.append(m)
    m = torch.stack(m_padded, dim=0)

    return [features, edge_indices, edge_features, graph_labels, m]
```

When I pass two samples of my graph dataset directly:

```python
out = collate_mydataset(tu_dataset[0:2])
```

it prints the output and works fine.

But when I pass the function to the DataLoader:

```python
import torch
from torch_geometric import loader

torch.manual_seed(42)
batch_size = 10
div_threshold = int(len(tu_dataset) * 0.8)  # 80/20 train/test split
train_dataset = tu_dataset[:div_threshold]
test_dataset = tu_dataset[div_threshold:]

train_loader = loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                                 collate_fn=collate_mydataset)
```

it doesn't even call the function and fails with:

stack expects each tensor to be equal size, but got [32, 9] at entry 0 and [15, 9] at entry 1

In this case it is saying that it can't stack the node features of two graphs. But if the function were called, this problem should not occur, because the function pads everything before stacking.

What could be the issue?
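The error itself comes from `torch.stack`, which requires every tensor to have the same shape. A minimal sketch, using random stand-in tensors shaped like the two graphs in the error message (32 and 15 nodes, 9 features each; the data is made up), shows the failure and how zero-padding along the node dimension avoids it:

```python
import torch

# Random stand-ins shaped like the two graphs in the error message.
x_a = torch.randn(32, 9)
x_b = torch.randn(15, 9)

stack_failed = False
try:
    torch.stack([x_a, x_b], dim=0)  # shapes differ, so this raises
except RuntimeError as err:
    stack_failed = True
    print("stack failed:", err)

# Zero-pad each matrix along the node dimension to the largest node count.
max_nodes = max(x_a.size(0), x_b.size(0))
padded = [
    torch.cat([x, x.new_zeros((max_nodes - x.size(0), x.size(1)))], dim=0)
    for x in (x_a, x_b)
]
batch = torch.stack(padded, dim=0)
print(batch.shape)  # torch.Size([2, 32, 9])
```

So if the custom collate function had actually run, the stack would have succeeded; the error means some other collate logic was used.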

This seems to be expected, since the PyTorch Geometric DataLoader implementation deletes your custom collate_fn and replaces it with its own Collater class, as seen here.
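To illustrate why passing `collate_fn` has no effect, here is a simplified, hypothetical subclass of `torch.utils.data.DataLoader` that follows the same pattern: it pops any user-supplied `collate_fn` from the keyword arguments and substitutes its own. This is not PyG's actual code; `DiscardingLoader`, `fixed_collate` and `my_collate` are illustrative names only.

```python
import torch
from torch.utils.data import DataLoader


def fixed_collate(samples):
    # Hypothetical built-in collate that simply stacks samples.
    return torch.stack(samples, dim=0)


class DiscardingLoader(DataLoader):
    """Hypothetical loader that ignores any user-supplied collate_fn."""

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        kwargs.pop("collate_fn", None)  # the user's function is dropped here
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                         collate_fn=fixed_collate, **kwargs)


calls = []


def my_collate(samples):
    calls.append(len(samples))  # records every call
    return samples


# Two "graphs" with different node counts.
data = [torch.zeros(3, 2), torch.zeros(5, 2)]
loader = DiscardingLoader(data, batch_size=2, collate_fn=my_collate)

error_raised = False
try:
    next(iter(loader))
except RuntimeError as err:
    error_raised = True
    print("my_collate calls:", calls, "| error:", err)
```

The symptom matches the one in the question: the custom function is never called, and the stacking error comes from the loader's own collate logic.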

Yes, the error points me to the same part of the code, but I do not know how the collate function should be structured.
Actually, the main question is: what should I do so that my custom collate function still works when PyG replaces it?
PS: in other resources (on text data, not graphs), they simply pass the custom collate function's name, as I have done here:

```python
train_loader = loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_mydataset)
```
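Since the PyG loader discards `collate_fn`, one possible workaround is to use PyTorch's own `torch.utils.data.DataLoader`, which does call the function you pass. The sketch below assumes each sample is a tuple of tensors, like the ones `collate_mydataset` indexes into; the toy dataset and the cut-down `pad_collate` (node features and labels only) are hypothetical stand-ins, not the original data or function.

```python
import torch
from torch.utils.data import DataLoader


def pad_collate(samples):
    # Hypothetical cut-down version of collate_mydataset: pads node features only.
    feats = [s[0] for s in samples]
    labels = torch.stack([s[1] for s in samples], dim=0)
    max_nodes = max(f.size(0) for f in feats)
    padded = [torch.cat([f, f.new_zeros((max_nodes - f.size(0), f.size(1)))], dim=0)
              for f in feats]
    return torch.stack(padded, dim=0), labels


# Toy dataset: (node_features, graph_label) tuples with varying node counts.
toy_dataset = [(torch.randn(n, 9), torch.tensor(n % 2)) for n in (32, 15, 20, 7)]

# torch.utils.data.DataLoader uses the collate_fn it is given.
train_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False,
                          collate_fn=pad_collate)

shapes = []
for features, labels in train_loader:
    shapes.append(tuple(features.shape))
    print(features.shape, labels.shape)
```

With a real dataset, the indexing inside the collate function has to match whatever each item actually is (a tuple, a dict, or a PyG `Data` object).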

I don’t know why PyG replaces the custom collate_fn, but @rusty1s would know.

I know this is an old thread, but it is interesting because I stumbled upon the same problem until I found the answer here. At this point, what are the benefits of using the DataLoader from PyG over the DataLoader from torch?
@rusty1s
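For context on that trade-off: rather than padding, PyG's loader batches graphs into one large disconnected graph (block-diagonal adjacency), which is what its message-passing layers expect and which wastes no memory on padding. A minimal plain-PyTorch sketch of the idea follows; it is not PyG's implementation, and `concat_batch` is a hypothetical name. Node features are concatenated, each graph's `edge_index` is shifted by the number of nodes before it, and a `batch` vector records which graph each node belongs to.

```python
import torch


def concat_batch(graphs):
    # graphs: list of (x, edge_index), x: [num_nodes, F], edge_index: [2, num_edges]
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)  # shift node ids into the combined graph
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edges, dim=1), torch.cat(batch, dim=0)


g1 = (torch.randn(3, 9), torch.tensor([[0, 1], [1, 2]]))
g2 = (torch.randn(2, 9), torch.tensor([[0], [1]]))
x, edge_index, batch = concat_batch([g1, g2])
print(x.shape)              # torch.Size([5, 9])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```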
