I have a collate function as below:
def collate_mydataset(samples):
    """Collate a list of graph samples into zero-padded, batched tensors.

    Each sample is expected to be a 5-tuple:
        (node_features  (num_nodes, F),
         edge_indices   (2, num_edges),
         edge_features  (num_edges, Fe),
         graph_label    (scalar or fixed-shape tensor),
         m              (num_nodes, K))
    — TODO confirm these shapes against the dataset producing the samples.

    Returns:
        [features, edge_indices, edge_features, graph_labels, m] where every
        per-graph tensor is zero-padded up to the batch maximum and stacked.
        NOTE: edge_indices is stacked along dim=1 (shape (2, B, max_E)),
        preserved from the original implementation; the other tensors are
        batch-first (dim=0).
    """
    print("hello collate function")
    print(samples)

    def _pad_rows(t, target_rows):
        """Zero-pad `t` along dim 0 up to `target_rows`.

        Padding is created with the SAME dtype as `t`: the original code used
        plain `torch.zeros(...)` (float32), which breaks `torch.cat` for
        integer tensors such as edge indices (error or silent promotion to
        float, ruining their use as indices).
        """
        deficit = target_rows - t.size(0)
        if deficit <= 0:
            return t
        pad = torch.zeros((deficit,) + tuple(t.shape[1:]), dtype=t.dtype)
        return torch.cat([t, pad], dim=0)

    def _pad_cols(t, target_cols):
        """Zero-pad `t` along dim 1 (used for the (2, num_edges) edge indices)."""
        deficit = target_cols - t.size(1)
        if deficit <= 0:
            return t
        pad = torch.zeros((t.size(0), deficit), dtype=t.dtype)
        return torch.cat([t, pad], dim=1)

    # Batch-wide padding targets, as in the original: node count from the
    # node-feature tensor, edge count from the edge-feature tensor.
    max_num_nodes = max(s[0].size(0) for s in samples)
    max_num_edges = max(s[2].size(0) for s in samples)

    features = torch.stack(
        [_pad_rows(s[0], max_num_nodes) for s in samples], dim=0)
    # dim=1 kept from the original code so callers see the same layout.
    edge_indices = torch.stack(
        [_pad_cols(s[1], max_num_edges) for s in samples], dim=1)
    edge_features = torch.stack(
        [_pad_rows(s[2], max_num_edges) for s in samples], dim=0)
    graph_labels = torch.stack([s[3] for s in samples], dim=0)
    m = torch.stack(
        [_pad_rows(s[4], max_num_nodes) for s in samples], dim=0)

    return [features, edge_indices, edge_features, graph_labels, m]
when I pass two samples of my graph dataset as:
out=collate_mydataset(tu_dataset[0:2])
it prints the output and works fine,
but when I pass the function to the DataLoader:
# Build an 80/20 train/test split and a training DataLoader.
from torch_geometric import loader
torch.manual_seed(42)
batch_size=10
# 80% cutoff index for the train split.
div_threshold = int(tu_dataset.__len__()*0.8)
train_dataset = tu_dataset[: div_threshold ]
test_dataset = tu_dataset[int(tu_dataset.__len__()*0.8):]
# NOTE(review): `collate_fn=collate_fn` passes a name that is never defined
# above — the custom function is called `collate_mydataset`. Passing an
# undefined/wrong callable means the custom padding logic never runs and the
# default collation raises the "stack expects each tensor to be equal size"
# error seen below. Presumably this should be
# `collate_fn=collate_mydataset` — verify.
# NOTE(review): this is torch_geometric's DataLoader, which installs its own
# Collater; a user-supplied collate_fn may be ignored or overridden. If the
# custom collate is required, `torch.utils.data.DataLoader` is likely the
# intended class — TODO confirm against the installed PyG version.
train_loader =loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=False,collate_fn=collate_fn)
it doesn’t even call the function and says
stack expects each tensor to be equal size, but got [32, 9] at entry 0 and [15, 9] at entry 1
which in this case means it cannot stack the node features of two graphs (but if my collate function were actually being called, this problem would not occur, because the function handles the padding and stacking itself),
what could be the issue?