"Placeholder storage has not been allocated on MPS device!" on macOS (M3)

I am trying to use a GCN (PyTorch Geometric) for graph classification with the following model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

mps_device = torch.device("mps")  # Apple Silicon (M3) GPU

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(1, hidden_channels).to(mps_device)
        self.conv2 = GCNConv(hidden_channels, hidden_channels).to(mps_device)
        self.conv3 = GCNConv(hidden_channels, hidden_channels).to(mps_device)
        self.lin = torch.nn.Linear(hidden_channels, num_classes).to(mps_device)

    def forward(self, x, edge_index, batch):
        # 1. Graph Convolutional Layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)

        # 2. Global Pooling
        x = global_mean_pool(x, batch)

        # 3. Linear Layer (Classifier)
        x = self.lin(x)

        return x

For training and evaluation I use:

# Define the model, loss function, and optimizer
model = GCN(hidden_channels=64, num_classes=3).to(mps_device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train():
    model.train()
    for data in loaderTrain:
        out = model(data.x, data.edge_index, data.batch)
        y = data.y.reshape(out.shape)  # Labels reshaped to match the output shape ([batch_size, num_classes]).
        loss = criterion(out, y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients for the next batch.

def test(loader):
    model.eval()

    correct = 0
    total = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Predicted class per graph.
        y = data.y.reshape(out.shape).argmax(dim=1)  # Target class from the per-class label vector.
        correct += int((pred == y).sum())
        total += y.shape[0]
    return correct / total

accuracy = []
val_accuracy = []
for epoch in range(1, 100 + 1):
    train()
    train_acc = test(loaderTrain)
    test_acc = test(loaderTest)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
    accuracy.append(train_acc)
    val_accuracy.append(test_acc)

If I don't specify any device, everything runs fine on the CPU. However, if I call .to(mps_device) on the data, the edge indices, and the model, I get the error below.
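The move to the MPS device is done explicitly on each batch, roughly like this (a simplified sketch of the relevant part, not the exact code):

for data in loaderTrain:
    data.x = data.x.to(mps_device)                    # node features ("the data")
    data.edge_index = data.edge_index.to(mps_device)  # edges
    out = model(data.x, data.edge_index, data.batch)

The model itself is moved once with .to(mps_device), as shown in the training code above.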

Here is the full error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 33
     31 val_accuracy = []
     32 for epoch in range(1, 100 + 1):
---> 33     train()
     34     train_acc = test(loaderTrain)
     35     test_acc = test(loaderTest)

Cell In[12], line 10
      8 model.train()
      9 for data in loaderTrain:
---> 10     out = model(data.x, data.edge_index, data.batch)
     11     y = data.y.reshape(out.shape)
     12     loss = criterion(out, y)  # Compute the loss.

File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[10], line 18
     15 x = self.conv3(x, edge_index)
     17 # 2. Global Pooling
---> 18 x = global_mean_pool(x, batch)
     20 # 3. Linear Layer (Classifier)
     21 x = self.lin(x)

File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
     61 if batch is None:
     62     return x.mean(dim=dim, keepdim=x.dim() <= 2)
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')

File ~/miniconda3/envs/macos/lib/python3.11/site-packages/torch_geometric/utils/_scatter.py:79, in scatter(src, index, dim, dim_size, reduce)
     77 if reduce == 'mean':
     78     count = src.new_zeros(dim_size)
---> 79     count.scatter_add_(0, index, src.new_ones(src.size(dim)))
     80     count = count.clamp(min=1)
     82     index = broadcast(index, src, dim)

RuntimeError: Placeholder storage has not been allocated on MPS device!