Can’t get this GCN to work. Any help is appreciated.
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
self.conv2 = GCNConv(16, int(dataset.num_classes), cached=True)
# self.conv1 = ChebConv(data.num_features, 16, K=2)
# self.conv2 = ChebConv(16, data.num_features, K=2)
self.reg_params = self.conv1.parameters()
self.non_reg_params = self.conv2.parameters()
def forward(self):
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
x = F.relu(self.conv1(x, edge_index, edge_weight))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index, edge_weight)
return F.log_softmax(x, dim=1)
def train():
model.train()
optimizer.zero_grad()
F.nll_loss(model()[data.train_mask], data.y[data.train_mask].long()).backward()
optimizer.step()
Error:
Traceback (most recent call last):
File "src/meta_gnn.py", line 395, in <module>
train()
File "src/meta_gnn.py", line 307, in train
F.nll_loss(model()[data.train_mask], data.y[data.train_mask].long()).backward()
File "/home/ppandey/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "src/meta_gnn.py", line 299, in forward
x = F.relu(self.conv1(x, edge_index, edge_weight))
File "/home/ppandey/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/ppandey/.local/lib/python3.6/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 87, in forward
x = torch.matmul(x, self.weight)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2' in call to _th_mm