Training doesn't converge when running on M1 pro GPU (MPS device)

Hi,

I’m trying to train a network model on a MacBook M1 Pro GPU using the MPS device, but for some reason the training doesn’t converge, and the final training loss is 10x higher on MPS than when training on the CPU.

Does anyone have any idea what could cause this?

def train():
    """Train the ES futures price model for a fixed number of epochs on the MPS device.

    Builds the model, dataset, loader, loss, and optimizer, then runs EPOCHS
    training passes, printing a running loss every 100 batches and the final
    per-epoch loss. No validation is performed yet.
    """
    # NOTE(review): MPS support has known numerical/op-coverage gaps in some
    # PyTorch versions — compare against CPU on the same seed to isolate the
    # divergence; confirm all ops used by the model are MPS-supported.
    device = torch.device('mps')

    EPOCHS = 5

    model = ESFuturesPricePredictorModel(maps_multiplier=1)
    model.to(device)

    dataset = ESFuturesDataset('data', ['5m'], [130], common_transform=[my_transform, torch_transform])

    # Shuffle for training; a validation loader is not set up yet.
    training_loader = torch.utils.data.DataLoader(
        dataset, batch_size=64 * 4, shuffle=True, num_workers=4)

    loss_fn = torch.nn.MSELoss()

    # Report split sizes
    print('Training set has {} instances'.format(len(dataset)))

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, betas=(0.9, 0.95))

    def train_one_epoch(epoch_index):
        """Run one full pass over the training data.

        Returns the mean loss of the most recent 100-batch window, or 0.0
        if fewer than 100 batches were processed (the window never filled).
        """
        running_loss = 0.
        last_loss = 0.

        for i, data in enumerate(training_loader):
            inputs, labels = data

            # Inputs/labels are dicts of tensors; move every tensor to the device.
            for k in inputs:
                inputs[k] = inputs[k].to(device)
            for k in labels:
                labels[k] = labels[k].to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = loss_fn(outputs, labels['5m'])
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            # Gather data and report every 100 batches.
            running_loss += loss.item()
            if i % 100 == 99:
                last_loss = running_loss / 100 # loss per batch
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                running_loss = 0.

        return last_loss

    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch + 1))

        # Ensure training mode (affects dropout/batch-norm layers, if any).
        model.train(True)
        avg_loss = train_one_epoch(epoch)

        print('LOSS train {}'.format(avg_loss))