PyTorch Geometric: dimension mismatch between input features and edge indices

Hello

I am trying to use PyTorch Geometric and get a dimension mismatch when I run the code. I take a batch of 32 texts, process one text at a time to create its adjacency matrix and edge indices, and then combine all 32 graphs into one batch. The GCN then raises an error. Could someone please look at the code below and correct me? Thanks.

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from transformers import DistilBertTokenizer

from transformers import DistilBertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.nn.functional as F
from torch_geometric.nn import GCNConv



# Load train data (read over HTTPS from the dataset's GitHub repository)
train_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/train.csv')

# Load test data
test_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/test.csv')


# NOTE(review): dropna() without `subset` drops rows with a missing value in
# ANY column, although FakeNewsDataset2 only reads 'text' and 'label';
# consider train_data.dropna(subset=['text', 'label'], inplace=True).
# test_data is not cleaned at all, so a missing 'text' there would reach the
# tokenizer as NaN — verify against the CSV.
train_data.dropna(inplace=True)



# %% prepare the data

''' This class takes in the data, tokenizes it using the DistilBertTokenizer from the transformers library,
 and returns the input IDs, attention masks, and labels.'''


class FakeNewsDataset2(Dataset):
    """Map-style dataset yielding ``(text, input_ids, attention_mask, label)``.

    Each row of ``data`` (indexed positionally with ``.iloc``) must provide a
    ``'text'`` and a ``'label'`` column. Texts are tokenized on access with the
    DistilBERT uncased tokenizer, padded and truncated to ``max_len`` tokens.
    """

    def __init__(self, data, max_len=128):
        self.data = data
        self.max_len = max_len
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        text, label = row['text'], row['label']
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of 1; drop it.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        target = torch.tensor(label, dtype=torch.long)
        return text, input_ids, attention_mask, target
       
# %%split the data into training and validation sets

# 80/20 split of the cleaned training data; random_state makes it reproducible.
train_data, val_data = train_test_split(train_data, test_size=0.2,
                                        random_state=42)

# %%

# Create PyTorch data loaders for the training, validation, and test sets:
# each batch is (texts, input_ids, attention_mask, labels) with 32 samples;
# only the training loader shuffles.
# NOTE(review): val_loader and test_loader are not used by the training loop
# shown in this file. If test.csv has no 'label' column, indexing
# test_dataset would raise KeyError — verify.
train_dataset = FakeNewsDataset2(train_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


val_dataset = FakeNewsDataset2(val_data)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


test_dataset = FakeNewsDataset2(test_data)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



    

# %%  train gcn

# from torch_geometric.utils import to_dense_adj
from torch_geometric.data import Data, Batch




def create_word_cooccurrence_adjacency(sentences):
    """Build a token graph linking tokens that start with the same character.

    Despite its name, ``sentences`` is a sequence of token strings (the
    tokens of one text). The tokens are sorted and duplicates are kept, so
    node ``k`` is the k-th token in sorted order, not the k-th position of
    the text.

    Args:
        sentences: sequence of non-empty token strings.

    Returns:
        torch.Tensor: float32 0/1 matrix of shape (n, n), n = len(sentences),
        with ``adj[i, j] == 1`` iff sorted tokens i and j share their first
        character (the diagonal is therefore always 1).
    """
    words = sorted(sentences)  # duplicates kept: one node per token occurrence
    # One pairwise tensor comparison instead of an n x n Python double loop,
    # which is run for every text of every batch.
    first_chars = torch.tensor([ord(word[0]) for word in words], dtype=torch.long)
    return (first_chars.unsqueeze(1) == first_chars.unsqueeze(0)).to(torch.float)

def adjacency_to_edge_index(adjacency_matrix):
    """Return the COO edge list of a dense adjacency matrix.

    Every non-zero entry ``adjacency_matrix[r, c]`` becomes one directed edge
    ``r -> c``; edges are listed in row-major order.

    Args:
        adjacency_matrix (torch.Tensor): 2-D (num_nodes, num_nodes) tensor.

    Returns:
        torch.Tensor: int64 tensor of shape (2, num_edges); row 0 holds the
        source node indices, row 1 the target node indices.
    """
    sources, targets = torch.nonzero(adjacency_matrix, as_tuple=True)
    return torch.stack((sources, targets), dim=0)


class GCN(torch.nn.Module):
    """Two-layer graph convolutional network producing one output row per node.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: size of the hidden node representation.
        out_channels: size of each output node vector (class logits when
            paired with CrossEntropyLoss, as in this script).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node outputs of shape (num_nodes, out_channels).

        ``x`` must be (num_nodes, in_channels) and every index in the
        (2, num_edges) ``edge_index`` must be smaller than num_nodes.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # No per-call print here: it would run on every forward pass and
        # flood the tqdm progress output.
        return self.conv2(x, edge_index)
    





def train_gcn_epoch(model, optimizer, criterion, train_loader):
    """Run one training epoch of ``model`` over per-text token graphs.

    For each text in a batch: re-tokenize it, build a token adjacency matrix
    with create_word_cooccurrence_adjacency, convert it to an edge_index,
    wrap it in a torch_geometric ``Data`` object, then collate the batch with
    ``Batch.from_data_list`` and run ``model`` on it.

    NOTE(review): as written this fails with the x / edge_index mismatch from
    the question. Each ``Data`` gets ``x = input_ids2[i].unsqueeze(0)`` of
    shape (1, 128): ONE node whose 128 "features" are the token ids. Its
    edge_index comes from a 128 x 128 adjacency and so refers to nodes
    0..127, which do not exist in ``x``. A fix needs one feature row per token
    (e.g. token embeddings, with the model's ``in_channels`` equal to the
    embedding size) and a per-graph pooling of the node outputs before the
    loss.

    Args:
        model: GCN called as ``model(x, edge_index)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)`` with int64 labels.
        train_loader: yields ``(texts, input_ids, attention_mask, labels)`` batches.

    Returns:
        tuple: (summed batch loss / number of batches,
                number of correct predictions / dataset size).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.train()
    train_loss = 0
    train_acc = 0

    for text,input_ids, attention_mask, labels in tqdm(train_loader, desc='Training'):

        # NOTE(review): token ids are categorical; casting them to float and
        # using the whole row as one feature vector turns each text into a
        # single node.
        input_ids2=input_ids.to(dtype=torch.float32).to(device).clone()
        labels2=labels.to(dtype=torch.float32).to(device).clone()  # cast back to int64 for the loss

        edge_index = []
        edge_index_batch=[]  # one (2, num_edges) edge_index per text

        # NOTE(review): from_pretrained reloads the tokenizer on every batch;
        # loading it once outside the loop is much cheaper.
        tokenizer=DistilBertTokenizer.from_pretrained('distilbert-base-uncased')



        adj_matrix_batch=[]  # adjacency matrices of the batch (collected, never read below)

        for j in range(len(text)):  # one text of the batch at a time (32 per batch with the loaders above)


            inputs = tokenizer.encode(
                text[j],
                add_special_tokens=True,
                max_length=130,
                padding='max_length',
                truncation=True)

            # Padded length 130, then [1:129] keeps 128 tokens after the first
            # special token; these can include [SEP] and [PAD] tokens.
            tokens = tokenizer.convert_ids_to_tokens(inputs)
            tokens=tokens[1:129]
            print(f'tokens {len(tokens)}')  # debug output, once per text



            # NOTE(review): create_word_cooccurrence_adjacency sorts the tokens,
            # so node k here is the k-th token in sorted order, not position k
            # of the text.
            adj_matrix=create_word_cooccurrence_adjacency(tokens)

            print(f'adj {adj_matrix}')  # debug output: prints a 128 x 128 matrix per text
            # print(f'input_ids {input_ids2.shape}')


            adj_matrix_batch.append(adj_matrix)
            edge_index=adjacency_to_edge_index(adj_matrix)
            edge_index_batch.append(edge_index)




        data_all_list=[]
        for i in range(len(text)):

            # NOTE(review): root cause of the error — x is (1, 128) (one node)
            # while edge_index_batch[i] indexes 128 nodes.
            data= Data(x=torch.unsqueeze(input_ids2[i],0), edge_index=edge_index_batch[i],y=labels2[i])
            data_all_list.append(data)


        # Collates the graphs into one disjoint graph; each graph's edge
        # indices are shifted by the node count of the graphs before it.
        batch = Batch.from_data_list(data_all_list)



        # print(f'x {batch.x.shape}')
        # print(f'edge {batch.edge_index.shape}')
        # print(f'label {labels.shape}')
        # print(f' num_nodes {batch.num_nodes}')
        # print(f' edge index max{batch.edge_index.max()}')

        # NOTE(review): torch.tensor() on existing tensors copies them (and
        # PyTorch warns about it). The model returns one row per NODE; a
        # per-graph pooling (e.g. global_mean_pool over batch.batch) is needed
        # so outputs line up with the per-text labels.
        outputs = model(torch.tensor(batch.x).to(device), torch.tensor(batch.edge_index).to(torch.int64).to(device))




        # NOTE(review): optimizer.zero_grad() is never called, so gradients
        # accumulate across all batches.
        loss = criterion(outputs, labels2.to(torch.int64))
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (outputs.argmax(1) == labels.to(device)).sum().item()

    train_loss /= len(train_loader)
    train_acc /= len(train_loader.dataset)

    return train_loss, train_acc



# %% evaluate GCNN

# NOTE(review): in_channels=128 equals the padded length of one text's
# input_ids (max_len=128), which is how train_gcn_epoch feeds each text: as a
# single node with 128 features. The graph built per text, however, has 128
# nodes, which is the source of the x / edge_index size mismatch.
model = GCN(in_channels=128, hidden_channels=64, out_channels=2).to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)

# Expects raw logits of shape (batch, 2) and int64 class indices.
criterion = nn.CrossEntropyLoss()

best_val_acc = 0
# Training loop
# NOTE(review): best_val_acc is never updated, the per-epoch train_loss /
# train_acc are overwritten without being reported, and no validation pass
# is run.
num_epochs = 25

for epoch in range(num_epochs):
    train_loss, train_acc = train_gcn_epoch(model, optimizer, criterion, train_loader)

**Homework AI Helper:**

This error usually means your edge_index references node indices that do not exist in x (your node features). Make sure:

- `edge_index.max() < x.size(0)`
- Node features `x` and `edge_index` are correctly sized and aligned.

@Joy Thanks for the response, but I am already aware of that, hence the title “dimension mismatch”. I posted this (fairly short) script hoping someone can explain how to do it correctly; I would be pleased to learn along the way.

hiii

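the main issue is that you’re passing token ids to the GCN layer. A GCN layer expects `x` to have shape `(num_nodes, num_node_features)`, i.e. one feature vector per node. Your code instead passes a tensor shaped `(1, 128)` per text: a single node whose 128 "features" are the token ids. Meanwhile, that text’s `edge_index` refers to 128 nodes.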

to fix your code with absolute minimal changes we would

  • import and use the bert_model to get the embedding of each token (token_ids => token_embeddings)
  • use `global_mean_pool` after the GCN layer(s) to get one representation per graph (aggregating the node representations)
  • keep input_ids as integers (no need to convert them to floats)
  • I noticed that you didn’t zero the gradients after each iteration, so they accumulated over the whole training run; I added `optimizer.zero_grad()`
  • set the GCN layer’s `in_channels` to 768 (the embedding size of DistilBERT) instead of 128

and the fixed code with absolute minimal changes would be


import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from transformers import DistilBertTokenizer

from transformers import DistilBertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.nn.functional as F
from torch_geometric.nn import GCNConv



# Load train data (read over HTTPS from the dataset's GitHub repository)
train_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/train.csv')

# Load test data
test_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/test.csv')


# NOTE(review): dropna() without `subset` drops rows with a missing value in
# ANY column, although FakeNewsDataset2 only reads 'text' and 'label';
# test_data is not cleaned at all — verify whether it has missing 'text'.
train_data.dropna(inplace=True)



# %% prepare the data

''' This class takes in the data, tokenizes it using the DistilBertTokenizer from the transformers library,
 and returns the input IDs, attention masks, and labels.'''


class FakeNewsDataset2(Dataset):
    """Dataset of tokenized news texts for DistilBERT.

    ``data`` is indexed positionally (``.iloc``) and each row must carry a
    ``'text'`` and a ``'label'`` column. ``__getitem__`` returns the raw text,
    the padded/truncated ``input_ids`` and ``attention_mask`` (length
    ``max_len``) and the label as an int64 scalar tensor.
    """

    def __init__(self, data, max_len=128):
        self.data = data
        self.max_len = max_len
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data.iloc[index]
        text = record['text']
        label = record['label']
        enc = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return (
            text,
            enc['input_ids'].squeeze(0),        # (1, max_len) -> (max_len,)
            enc['attention_mask'].squeeze(0),
            torch.tensor(label, dtype=torch.long),
        )
       
# %%split the data into training and validation sets

# 80/20 split of the cleaned training data; random_state makes it reproducible.
train_data, val_data = train_test_split(train_data, test_size=0.2,
                                        random_state=42)

# %%

# Create PyTorch data loaders for the training, validation, and test sets:
# each batch is (texts, input_ids, attention_mask, labels) with 32 samples;
# only the training loader shuffles.
# NOTE(review): val_loader / test_loader are not used by the training loop
# shown in this file.

train_dataset = FakeNewsDataset2(train_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


val_dataset = FakeNewsDataset2(val_data)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


test_dataset = FakeNewsDataset2(test_data)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



    

# %%  train gcn

# from torch_geometric.utils import to_dense_adj
from torch_geometric.data import Data, Batch




def create_word_cooccurrence_adjacency(sentences):
    """Build a token graph linking tokens that start with the same character.

    ``sentences`` is a sequence of token strings (one text's tokens). The
    tokens are sorted and duplicates are kept, so node ``k`` is the k-th
    token in sorted order, not position k of the text. Callers that attach
    per-position features must reorder them with the same sort.

    Args:
        sentences: sequence of non-empty token strings.

    Returns:
        torch.Tensor: float32 0/1 matrix of shape (n, n), n = len(sentences),
        with ``adj[i, j] == 1`` iff sorted tokens i and j share their first
        character (the diagonal is therefore always 1).
    """
    words = sorted(sentences)  # not unique: one node per token occurrence
    # Compare all first characters pairwise in a single tensor op rather than
    # an n x n Python double loop per text.
    first_chars = torch.tensor([ord(word[0]) for word in words], dtype=torch.long)
    return (first_chars.unsqueeze(1) == first_chars.unsqueeze(0)).to(torch.float)

def adjacency_to_edge_index(adjacency_matrix):
    """Convert a dense adjacency matrix to a PyG-style ``edge_index``.

    Args:
        adjacency_matrix (torch.Tensor): 2-D (num_nodes, num_nodes) tensor;
            any non-zero entry ``[r, c]`` is an edge from r to c.

    Returns:
        torch.Tensor: int64 (2, num_edges) tensor whose columns are the
        (source, target) pairs in row-major order.
    """
    # nonzero() lists one (row, col) pair per line; split it into the two
    # coordinate vectors and stack them as the two rows of edge_index.
    row_idx, col_idx = adjacency_matrix.nonzero().unbind(1)
    return torch.stack([row_idx, col_idx], dim=0)


class GCN(torch.nn.Module):
    """Two stacked GCNConv layers with a ReLU in between.

    Maps node features of shape (num_nodes, in_channels) to node outputs of
    shape (num_nodes, out_channels); pooling to one row per graph is left to
    the caller.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
    





def train_gcn_epoch(model, optimizer, criterion, train_loader):
    """Run one training epoch of ``model`` over per-text token graphs.

    Each text becomes one graph. Its nodes are the token positions of the
    dataset's ``input_ids`` after [CLS]; these are exactly the ids the
    DistilBERT encoder sees. Node features are the frozen ``bert_model``
    hidden states at those positions, and edges come from
    ``create_word_cooccurrence_adjacency``. Node outputs are mean-pooled per
    graph and scored with ``criterion`` against the batch labels.

    Args:
        model: GCN taking (num_nodes, 768) node features and an edge_index.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (batch, num_classes) logits and int64 labels.
        train_loader: yields ``(texts, input_ids, attention_mask, labels)``.

    Returns:
        tuple: (mean loss per batch, accuracy over ``train_loader.dataset``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.train()
    train_loss = 0
    train_acc = 0

    # Loading a tokenizer is slow; do it once per epoch instead of per batch.
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    for text, input_ids, attention_mask, labels in tqdm(train_loader, desc='Training'):
        input_ids2 = input_ids.to(device)
        attention_mask2 = attention_mask.to(device)
        labels2 = labels.to(device)

        # Frozen BERT features: embeddings[i, j] belongs to input_ids[i, j].
        with torch.no_grad():
            embeddings = bert_model(input_ids=input_ids2, attention_mask=attention_mask2).last_hidden_state

        data_all_list = []
        for i in range(len(text)):
            # Build the graph from the very ids BERT embedded (minus [CLS]) so
            # node count and node identity match the feature rows. Re-encoding
            # the raw text with max_length=130 shifted and truncated the tokens
            # relative to the 128-long embeddings.
            tokens = tokenizer.convert_ids_to_tokens(input_ids[i].tolist())[1:]
            node_features = embeddings[i, 1:]

            # The adjacency lays nodes out in sorted-token order; apply the
            # same stable sort to the feature rows so row k <-> node k.
            order = sorted(range(len(tokens)), key=tokens.__getitem__)
            adj_matrix = create_word_cooccurrence_adjacency(tokens)
            edge_index = adjacency_to_edge_index(adj_matrix).to(device)

            data_all_list.append(Data(x=node_features[order], edge_index=edge_index, y=labels2[i]))

        batch = Batch.from_data_list(data_all_list).to(device)

        optimizer.zero_grad()
        # batch.x / batch.edge_index are already tensors on `device`; wrapping
        # them in torch.tensor() only copied them and triggered a warning.
        outputs = model(batch.x, batch.edge_index)
        # One logit row per text: average the node outputs of each graph.
        outputs = global_mean_pool(outputs, batch.batch)

        loss = criterion(outputs, labels2)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (outputs.argmax(1) == labels2).sum().item()

    train_loss /= len(train_loader)
    train_acc /= len(train_loader.dataset)

    return train_loss, train_acc

# Frozen DistilBERT encoder used inside train_gcn_epoch to turn token ids into
# contextual embeddings (the GCN's node features). train_gcn_epoch looks it up
# as a module-level global when it is called.
from transformers import DistilBertTokenizer, DistilBertModel
bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model.eval()  # set to eval mode to disable dropout etc.
bert_model.to(device)  # move to GPU if available
# %% evaluate GCNN

# global_mean_pool is used inside train_gcn_epoch; importing it here works
# because the function body resolves the name only when it runs.
from torch_geometric.nn import global_mean_pool
# in_channels=768 matches the size of the DistilBERT hidden states used as node features.
model = GCN(in_channels=768, hidden_channels=64, out_channels=2).to(device)

# Only the GCN's parameters are optimized; bert_model stays frozen.
optimizer = optim.Adam(model.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

best_val_acc = 0
# Training loop
# NOTE(review): best_val_acc is never updated and per-epoch results are not reported.
num_epochs = 25

for epoch in range(num_epochs):
    train_loss, train_acc = train_gcn_epoch(model, optimizer, criterion, train_loader)

a slightly better managed version would be something like


import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch

from transformers import DistilBertTokenizer
from transformers import DistilBertTokenizer, DistilBertModel

# First CUDA device when available, otherwise CPU; used by all code below.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/train.csv')
test_data = pd.read_csv('https://raw.githubusercontent.com/salarMokhtariL/Facke-News-Detection/main/Dataset/test.csv')
# NOTE(review): drops rows with a missing value in ANY column, although only
# 'text' and 'label' are used; test_data is not cleaned — verify.
train_data.dropna(inplace=True)


class FakeNewsDataset2(Dataset):
    """Tokenizing dataset over a DataFrame with 'text' and 'label' columns.

    Each item is ``(text, input_ids, attention_mask, label)``; the id and mask
    tensors have length ``max_len`` (padded/truncated), and the label is an
    int64 scalar tensor.
    """

    def __init__(self, data, max_len=128):
        self.data = data
        self.max_len = max_len
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data.iloc[index]
        text = sample['text']
        label = sample['label']
        encoding = self.tokenizer.encode_plus(
            text,
            return_tensors='pt',
            return_attention_mask=True,
            return_token_type_ids=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            add_special_tokens=True,
        )
        ids = encoding['input_ids'].squeeze(0)
        mask = encoding['attention_mask'].squeeze(0)
        return text, ids, mask, torch.tensor(label, dtype=torch.long)
       

# 80/20 split of the cleaned training data; random_state makes it reproducible.
train_data, val_data = train_test_split(train_data, test_size=0.2,random_state=42)

train_dataset = FakeNewsDataset2(train_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# NOTE(review): val_loader / test_loader are not used by the training loop shown here.
val_dataset = FakeNewsDataset2(val_data)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

test_dataset = FakeNewsDataset2(test_data)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


# Module-level tokenizer and frozen DistilBERT encoder; train_gcn_epoch reads
# both as globals. eval() disables dropout in the encoder.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
bert_model.eval()


def create_word_cooccurrence_adjacency(tokens):
    """Build a graph over the unique tokens, linking tokens with the same first character.

    Args:
        tokens: sequence of non-empty token strings (duplicates allowed).

    Returns:
        tuple: ``(adj, token_to_idx)``. ``token_to_idx`` maps each unique
        token to its node index; tokens are sorted, so indices follow sorted
        order. ``adj`` is a float32 0/1 matrix of shape (n, n), n = number of
        unique tokens, with ``adj[i, j] == 1`` iff unique tokens i and j start
        with the same character (the diagonal is always 1).
    """
    unique_tokens = sorted(set(tokens))  # unique tokens for smaller graphs
    token_to_idx = {tok: i for i, tok in enumerate(unique_tokens)}

    # One pairwise tensor comparison replaces the n x n Python double loop
    # that ran for every text of every batch.
    first_chars = torch.tensor([ord(tok[0]) for tok in unique_tokens], dtype=torch.long)
    adj = (first_chars.unsqueeze(1) == first_chars.unsqueeze(0)).to(torch.float)
    return adj, token_to_idx


def adjacency_to_edge_index(adj):
    """Return the (2, num_edges) int64 ``edge_index`` of non-zero entries of ``adj``.

    Row 0 holds source node indices and row 1 target node indices, in
    row-major order of ``adj``.
    """
    src, dst = torch.where(adj != 0)
    return torch.stack((src, dst))



class GCN(torch.nn.Module):
    """Graph convolutional network: GCNConv -> ReLU -> GCNConv.

    Args:
        in_channels: node feature size of the input ``x``.
        hidden_channels: node feature size after the first convolution.
        out_channels: node feature size of the output (per node, not pooled).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        out = self.conv2(h, edge_index)
        return out
    


def train_gcn_epoch(model, optimizer, criterion, train_loader):
    """Run one training epoch of ``model`` over per-text graphs of unique tokens.

    For each text, the dataset's ``input_ids`` are turned back into tokens.
    Each unique token after [CLS] becomes one node, whose features are the
    frozen ``bert_model`` hidden state at the token's first position in the
    sequence. Edges come from ``create_word_cooccurrence_adjacency``. Node
    outputs are mean-pooled per graph and scored with ``criterion``.

    Args:
        model: GCN taking (num_nodes, 768) node features and an edge_index.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (batch, num_classes) logits and int64 labels.
        train_loader: yields ``(texts, input_ids, attention_mask, labels)``.

    Returns:
        tuple: (mean loss per batch, accuracy over ``train_loader.dataset``).
    """
    model.train()
    train_loss = 0
    train_acc = 0

    for text_batch, input_ids, attention_mask, labels in tqdm(train_loader, desc='Training'):
        input_ids = input_ids.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        # computing the embeddings from ids, no need to keep track of grads here
        with torch.no_grad():
            embeddings = bert_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (bs, seq_len, 768)

        pyg_data_list = []
        for i in range(len(text_batch)):
            # The dataset already encoded this text with the same tokenizer and
            # max_length, so reuse its ids instead of re-encoding the raw text.
            seq_tokens = tokenizer.convert_ids_to_tokens(input_ids[i].tolist())

            # First position of every token in the sequence, built in one pass
            # (replaces a list.index scan per node).
            first_pos = {}
            for pos, tok in enumerate(seq_tokens):
                first_pos.setdefault(tok, pos)

            tokens = seq_tokens[1:129]  # drop [CLS]
            adj, token_to_idx = create_word_cooccurrence_adjacency(tokens)
            edge_index = adjacency_to_edge_index(adj).to(device)

            # Row k of x is the embedding of the token mapped to node k.
            node_positions = [0] * len(token_to_idx)
            for tok, node in token_to_idx.items():
                node_positions[node] = first_pos.get(tok, 0)
            x = embeddings[i, node_positions]

            pyg_data_list.append(Data(x=x, edge_index=edge_index, y=labels[i].unsqueeze(0)))

        batch = Batch.from_data_list(pyg_data_list).to(device)

        optimizer.zero_grad()
        # batch.x / batch.edge_index are already tensors on `device`; wrapping
        # them in torch.tensor() only copied them and triggered a warning.
        outputs = model(batch.x, batch.edge_index)
        outputs = global_mean_pool(outputs, batch.batch)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (outputs.argmax(1) == labels).sum().item()

    train_loss /= len(train_loader)
    train_acc /= len(train_loader.dataset)

    return train_loss, train_acc


# in_channels=768 matches the DistilBERT hidden states used as node features.
model = GCN(in_channels=768, hidden_channels=64, out_channels=2).to(device)
# Only the GCN's parameters are optimized; bert_model stays frozen.
optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# NOTE(review): best_val_acc is never updated and per-epoch results are not reported.
best_val_acc = 0
num_epochs = 25

for epoch in range(num_epochs):
    train_loss, train_acc = train_gcn_epoch(model, optimizer, criterion, train_loader)


note: I didn’t really focus on the logic / aim of the code, I just fixed the obvious things; you can easily improve it further

hope that helped

@Dhia-naouali Thanks for the detailed answer and for spending time on this script. I am trying to understand it to educate myself, and you explained it nicely. I ran both scripts. The second, better one uses unique tokens, so the number of nodes decreased but the number of edge indices increased. In the first approach the number of nodes is slightly larger (fixed size, since the tokens are not unique), but the number of edge indices is smaller. May I know why this happens and why the second approach is better?

My other, conceptual question: why does a GCN take embeddings rather than input ids, as standard transformers do?

My questions clearly show that I am a novice. It would be nice to learn more, as my mistakes come from gaps in these concepts.
Thanks.

for the first part,
it really depends on the task at hand and the strategy you want to adopt. In the correction I focused on making the run faster rather than on that (my 2 GB GPU could never XD).
i.e. you can either create a lookup table / matrix over the unique tokens that maps each token to its “future neighbors in the graph” and then look up each token’s row, or compute the matrix for all (non-unique) tokens and use it directly later on. As I mentioned, I didn’t really focus on that.

for the second part, which is the original issue your script had
Transformer-based models / language models basically consist of an embedding layer followed by a series of transformer blocks.
When you input token ids, the embedding layer retrieves the token embeddings and passes them on to the rest of the blocks (i.e. [token_ids ==embedding layer==> token_embeddings] => transformer block).
Each model (BERT, LLaMA, …) has its own embedding table.

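A GNN, on the other hand, expects these embeddings directly, just like the transformer blocks do. It has no embedding table of its own: a GCN layer only applies a learned linear transform to the feature vectors it receives. Feeding it raw ids would therefore treat each id number as a magnitude rather than as a category.
A more intuitive way to look at it:
each token (a word or a chunk of characters) is represented by a state vector, which you can think of as a set of scores: how positive the word is, how dark it is, how affirmative it is, …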

If you’re just starting, I would recommend the Stanford lectures introducing the transformer architecture, plus this paper comparing Transformers with GNNs. Enjoy!