Hi, I am having issues with implementing a GAT model. The model runs when I dont have edge features, but once I had edge_attr into the GATConv layers I get the following error.
Traceback (most recent call last):
File "/home/alecsanc/Documents/pka_prediction/datasets/GCNN_testing/GAT_testing/train_fg.py", line 212, in <module>
pred, embedding = model(batch.x.float(),batch.edge_index,batch.edge_attr.float(), batch.batch)
File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/alecsanc/Documents/pka_prediction/datasets/GCNN_testing/GAT_testing/train_fg.py", line 163, in forward
hidden = self.initial_conv(x, edge_index,edge_attr)
File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/gat_conv.py", line 242, in forward
out = self.propagate(edge_index, x=x, alpha=alpha, edge_attr=edge_attr,
File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 317, in propagate
out = self.message(**msg_kwargs)
File "/home/alecsanc/anaconda3/envs/ml/lib/python3.9/site-packages/torch_geometric/nn/conv/gat_conv.py", line 275, in message
assert self.lin_edge is not None
AssertionError
Any direction would be appreciated. I also included my data object as well.
Data(x=[9, 8], edge_index=[2, 18], edge_attr=[18, 5], y=[1, 1], smiles='Oc1ccc(Cl)c(Cl)c1')
CODE:
atom_num_features = 8
edge_num_features = 5
embedding_size = 8
heads = 8
class GAT(torch.nn.Module):
def __init__(self):
# Init parent
super(GAT, self).__init__()
torch.manual_seed(42)
# 2 GAT layers.
self.initial_conv = GATConv(atom_num_features, embedding_size,heads,edge_num_features)
self.conv1 = GATConv(embedding_size * heads, embedding_size)
self.conv2 = GATConv(embedding_size * heads, embedding_size)
# Output layer
self.out = Linear(embedding_size*2,1)
def forward(self, x, edge_index, edge_attr,batch_index):
# First Conv layer
hidden = self.initial_conv(x, edge_index,edge_attr)
hidden = functional.leaky_relu(hidden)
# Other Conv layers
hidden = self.conv1(hidden, edge_index)
hidden = functional.leaky_relu(self.conv1_bn(hidden))
hidden = self.conv2(hidden, edge_index)
hidden = functional.leaky_relu(self.conv2_bn(hidden))
# Global Pooling (stack different aggregations)
hidden = torch.cat([gmp(hidden, batch_index),
gap(hidden, batch_index)], dim=1)
# Apply a final (linear) classifier.
out = self.out(hidden)
return out,hidden