GAT edge features Assertion Error

Hi, I am having issues with implementing a GAT model. The model runs when I dont have edge features, but once I had edge_attr into the GATConv layers I get the following error.

Traceback (most recent call last):
  File "/home/alecsanc/Documents/pka_prediction/datasets/GCNN_testing/GAT_testing/", line 212, in <module>
    pred, embedding = model(batch.x.float(),batch.edge_index,batch.edge_attr.float(), batch.batch)
  File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/alecsanc/Documents/pka_prediction/datasets/GCNN_testing/GAT_testing/", line 163, in forward
    hidden = self.initial_conv(x, edge_index,edge_attr)
  File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/", line 242, in forward
    out = self.propagate(edge_index, x=x, alpha=alpha, edge_attr=edge_attr,
  File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/", line 317, in propagate
    out = self.message(**msg_kwargs)
  File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/", line 275, in message
    assert self.lin_edge is not None

Any direction would be appreciated. I also included my data object as well.

Data(x=[9, 8], edge_index=[2, 18], edge_attr=[18, 5], y=[1, 1], smiles='Oc1ccc(Cl)c(Cl)c1')


atom_num_features = 8
edge_num_features = 5
embedding_size = 8
heads = 8
class GAT(torch.nn.Module):
    def __init__(self):
        # Init parent
        super(GAT, self).__init__()

        # 2 GAT layers. 
        self.initial_conv = GATConv(atom_num_features, embedding_size,heads,edge_num_features)
        self.conv1 = GATConv(embedding_size * heads, embedding_size)
        self.conv2 = GATConv(embedding_size * heads, embedding_size)
        # Output layer
        self.out = Linear(embedding_size*2,1)

    def forward(self, x, edge_index, edge_attr,batch_index):
        # First Conv layer
        hidden = self.initial_conv(x, edge_index,edge_attr)
        hidden = functional.leaky_relu(hidden)

        # Other Conv layers
        hidden = self.conv1(hidden, edge_index)
        hidden = functional.leaky_relu(self.conv1_bn(hidden))
        hidden = self.conv2(hidden, edge_index)
        hidden = functional.leaky_relu(self.conv2_bn(hidden))
        # Global Pooling (stack different aggregations)
        hidden =[gmp(hidden, batch_index),
                            gap(hidden, batch_index)], dim=1)

# Apply a final (linear) classifier.
        out = self.out(hidden)
        return out,hidden

You need to pass the “edge_dim=11” argument to the GATConv layer