MPS on M2 Max working poorly

I have a MacBook Pro M2 Max and tried to run my first training loop with device = 'mps'.
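Instead of hard-coding device = 'mps', here is a minimal sketch that falls back to the CPU when MPS can't be used. The helper name pick_device is hypothetical; torch.backends.mps.is_available() is the standard PyTorch check.

```python
import torch

def pick_device() -> torch.device:
    """Return the MPS device when PyTorch can use it, otherwise the CPU."""
    # is_available() is False if this PyTorch build lacks MPS support
    # or no usable Metal device is found on the machine.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
```

Because the rest of the code only ever refers to device, the same script can then be run on MPS and on the CPU to compare behaviour.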

I’m running a simple matrix factorization model for a collaborative filtering problem, R = U V^T, where U and V share a latent factor dimension. My kernel dies every time I try to run the training loop, except with the most trivial model (latent factor dim = 1) and batch size (n = 1). Even with these trivial settings the model runs quite slowly, though it does progress through the epochs.
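For reference, here is a minimal sketch of the kind of model described. It assumes U and V are stored as nn.Embedding tables and that each prediction is the dot product of one user row and one item row. The class name MatrixFactorization and the sizes are hypothetical, since the actual model isn't shown.

```python
import torch
from torch import nn

class MatrixFactorization(nn.Module):
    """Predicts R[u, i] as the dot product of user factors U[u] and item factors V[i]."""

    def __init__(self, n_users: int, n_items: int, n_factors: int):
        super().__init__()
        self.user_factors = nn.Embedding(n_users, n_factors)  # rows of U
        self.item_factors = nn.Embedding(n_items, n_factors)  # rows of V

    def forward(self, user: torch.Tensor, item: torch.Tensor) -> torch.Tensor:
        # user and item are integer (int64) index tensors of shape (batch,).
        # Elementwise product, then sum over the latent dimension, gives the
        # matching entries of U @ V.T without materialising the full matrix.
        return (self.user_factors(user) * self.item_factors(item)).sum(dim=1)

model = MatrixFactorization(n_users=100, n_items=50, n_factors=8)
users = torch.tensor([0, 1, 2])
items = torch.tensor([3, 4, 5])
print(model(users, items).shape)  # torch.Size([3])
```

Computing only the requested (user, item) entries keeps memory proportional to the batch size, not to n_users × n_items.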

The training loop runs as expected on the CPU.

I’m somewhat new to PyTorch and to training on GPUs, so I may be doing something wrong.

Training loop code below:

import numpy as np
import torch

# Validation losses seen so far, for early stopping; initialise once, before the epoch loop.
prev_val_loss = [np.inf]

for epoch in range(EPOCH_NUM):
    model.train()
    train_loss = 0.0
    train_mae = 0.0
    for user_item_idx, rating in train_loader:
        user, item = user_item_idx[:, 0].to(device), user_item_idx[:, 1].to(device)
        # MPS does not support float64, so make sure the targets are float32.
        rating = rating.to(device, dtype=torch.float32)
        optimizer.zero_grad()
        prediction = model(user, item)
        batch_loss = loss(prediction, rating)
        batch_loss.backward()
        optimizer.step()
        train_loss += batch_loss.item()
        # detach() so the MAE computation isn't added to the autograd graph.
        train_mae += torch.abs(prediction.detach() - rating).mean().item()
    train_loss /= len(train_loader)
    train_mae /= len(train_loader)

    model.eval()
    val_loss = 0.0
    val_mae = 0.0
    # Gradients aren't needed for validation; no_grad() saves memory and time.
    with torch.no_grad():
        for user_item_idx, rating in val_loader:
            user, item = user_item_idx[:, 0].to(device), user_item_idx[:, 1].to(device)
            rating = rating.to(device, dtype=torch.float32)
            prediction = model(user, item)
            val_loss += loss(prediction, rating).item()
            val_mae += torch.abs(prediction - rating).mean().item()
    val_loss /= len(val_loader)
    val_mae /= len(val_loader)

    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1} | Train loss: {train_loss:.4f} | Train MAE: {train_mae:.4f} | Val loss: {val_loss:.4f} | Val MAE: {val_mae:.4f}")

    # Early stopping: stop once val loss exceeds the best value so far by more than 5%.
    if epoch > 0 and val_loss > min(prev_val_loss) * 1.05:
        break
    prev_val_loss.append(val_loss)
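Calling .item() on every batch copies each scalar from the GPU back to the CPU and makes the CPU wait for the GPU to finish the queued work. With tiny batches (batch size 1 especially), that overhead can dominate the runtime. The sketch below shows one way to keep the running sums on the device and synchronise only once per epoch. The helper train_one_epoch, the ToyMF model, and the synthetic data are hypothetical stand-ins, not the poster's code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one training epoch; return (mean loss, mean MAE) as Python floats."""
    model.train()
    # Running sums live on the device, so there is no per-batch .item() sync.
    loss_sum = torch.zeros((), device=device)
    mae_sum = torch.zeros((), device=device)
    for user_item_idx, rating in loader:
        user_item_idx = user_item_idx.to(device)
        rating = rating.to(device, dtype=torch.float32)
        optimizer.zero_grad()
        prediction = model(user_item_idx[:, 0], user_item_idx[:, 1])
        batch_loss = loss_fn(prediction, rating)
        batch_loss.backward()
        optimizer.step()
        # detach() so the accumulators don't keep each batch's autograd graph alive.
        loss_sum += batch_loss.detach()
        mae_sum += (prediction.detach() - rating).abs().mean()
    n_batches = len(loader)
    # One .item() per metric: a single sync per epoch instead of one per batch.
    return (loss_sum / n_batches).item(), (mae_sum / n_batches).item()


class ToyMF(nn.Module):
    # Hypothetical stand-in model: dot product of user and item embeddings.
    def __init__(self, n_users, n_items, k):
        super().__init__()
        self.u = nn.Embedding(n_users, k)
        self.v = nn.Embedding(n_items, k)

    def forward(self, user, item):
        return (self.u(user) * self.v(item)).sum(dim=1)


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)
idx = torch.stack([torch.randint(0, 20, (256,)), torch.randint(0, 30, (256,))], dim=1)
ratings = torch.rand(256) * 5  # synthetic ratings in [0, 5)
loader = DataLoader(TensorDataset(idx, ratings), batch_size=32, shuffle=True)

model = ToyMF(20, 30, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()
for epoch in range(3):
    train_loss, train_mae = train_one_epoch(model, loader, loss_fn, optimizer, device)
    print(f"Epoch {epoch + 1} | Train loss: {train_loss:.4f} | Train MAE: {train_mae:.4f}")
```

The reported numbers are the same per-batch averages the original loop computes. Only the point at which values are copied back to the CPU changes.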

What error are you getting exactly? Keep an eye on the macOS Activity Monitor; it might give you a hint as to what’s going on.
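Alongside Activity Monitor, PyTorch can report MPS memory use from inside the script. The sketch below assumes a recent PyTorch release (2.0 or later, as far as I know) that provides torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory(). The helper name mps_memory_report is hypothetical.

```python
import torch

def mps_memory_report() -> str:
    """Describe current MPS memory use, or explain why it can't be measured."""
    if not torch.backends.mps.is_available():
        return "MPS not available; nothing to report"
    mib = 1024 ** 2
    # Memory held by live tensors vs. total memory the Metal driver has
    # allocated for this process (which also includes cached blocks).
    allocated = torch.mps.current_allocated_memory() / mib
    reserved = torch.mps.driver_allocated_memory() / mib
    return f"MPS allocated: {allocated:.1f} MiB | driver reserved: {reserved:.1f} MiB"

print(mps_memory_report())
```

Printing this once per epoch, right after the loss line, shows whether memory grows steadily before the kernel dies.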