I am trying to use a GCN for graph classification using the following model:
class GCN(torch.nn.Module):
def __init__(self, hidden_channels, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(1, hidden_channels).to(mps_device)
self.conv2 = GCNConv(hidden_channels, hidden_channels).to(mps_device)
self.conv3 = GCNConv(hidden_channels, hidden_channels).to(mps_device)
self.lin = torch.nn.Linear(hidden_channels, num_classes).to(mps_device)
def forward(self, x, edge_index, batch):
# 1. Graph Convolutional Layers
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = self.conv3(x, edge_index)
# 2. Global Pooling
x = global_mean_pool(x, batch)
# 3. Linear Layer (Classifier)
x = self.lin(x)
return x
For training I use:
# Define your model, loss function, optimizer
model = GCN(hidden_channels=64, num_classes=3).to(mps_device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
def train():
model.train()
for data in loaderTrain:
out = model(data.x, data.edge_index, data.batch)
y = data.y.reshape(out.shape)
loss = criterion(out, y) # Compute the loss.
loss.backward() # Derive gradients.
optimizer.step() # Update parameters based on gradients.
optimizer.zero_grad()
def test(loader):
model.eval()
correct = 0
all = 0
for data in loader:
out = model(data.x, data.edge_index, data.batch)
pred = out.argmax(dim = 1)
y = data.y.reshape(out.shape).argmax(dim = 1)
correct += (pred == y).sum()
all += y.shape[0]
return correct / all
accuracy = []
val_accuracy = []
for epoch in range(1, 100 + 1):
train()
train_acc = test(loaderTrain)
test_acc = test(loaderTest)
print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
accuracy.append(train_acc)
val_accuracy.append(test_acc)
If I don’t specify any device it runs fine on CPU, however, if I use .to(mps_device)
on the data, edges and model I get the error.
Here is the full error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[12], line 33
31 val_accuracy = []
32 for epoch in range(1, 100 + 1):
---> 33 train()
34 train_acc = test(loaderTrain)
35 test_acc = test(loaderTest)
Cell In[12], line 10
8 model.train()
9 for data in loaderTrain:
---> 10 out = model(data.x, data.edge_index, data.batch)
11 y = data.y.reshape(out.shape)
12 loss = criterion(out, y) # Compute the loss.
File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
Cell In[10], line 18
15 x = self.conv3(x, edge_index)
17 # 2. Global Pooling
---> 18 x = global_mean_pool(x, batch)
20 # 3. Linear Layer (Classifier)
21 x = self.lin(x)
File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
61 if batch is None:
62 return x.mean(dim=dim, keepdim=x.dim() <= 2)
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')
File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py:79, in scatter(src, index, dim, dim_size, reduce)
77 if reduce == 'mean':
78 count = src.new_zeros(dim_size)
---> 79 count.scatter_add_(0, index, src.new_ones(src.size(dim)))
80 count = count.clamp(min=1)
82 index = broadcast(index, src, dim)
RuntimeError: Placeholder storage has not been allocated on MPS device!