Simple cnn model failed to move to mps

Hi everyone, I am trying to use torch 2.1.1 to train on mps gpu. There is always runtime error says RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same. However, I already called model.to(“mps”). I think there should be a problem about loading the model to gpu.

What I have tried: update OS from 12.3.1 to 14.2.2. Update Python from 3.9 to 3.10. uninstall miniforge and reinstall.pip uninstall all and reinstall latest pytorch.

I am running my code in jupyter notebook. My laptop has M1 chip. next(model.parameters()).device returns exactly mps:0.

Thanks a lot

class CNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(CNN, self).__init__()
        self.c1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.c2 = nn.Sequential(
            nn.Conv2d(64,32,3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2,2))
        self.c3 = nn.Sequential(
            nn.Conv2d(32,16,3),
            nn.BatchNorm2d(16),
            nn.ReLU())
        self.c4 = nn.Sequential(
            nn.Conv2d(16,32,3),
            nn.BatchNorm2d(32),
            nn.ReLU())
        self.c5 = nn.Sequential(
            nn.Conv2d(32,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2,2))
        self.out = nn.Sequential(
            nn.Linear(64*4*4,64),
            nn.Sigmoid(),
            nn.Linear(64,num_classes))

    def forward(self, x):
        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)
        x = self.c4(x)
        x = self.c5(x)
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

# initialize your model
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device) # output: mps
cnn_model = CNN(n_channels,n_classes).to(device)
print(next(cnn_model.parameters()).device) #output: mps:0

# define the BCE loss function and an optimizer you prefer
loss_f = nn.BCEWithLogitsLoss(reduction='mean')
op = torch.optim.Adam(cnn_model.parameters(),lr=lr)


def train(train_loader, val_loader, model):
    loss_values = []
    val_loss_values = []
    best_loss = 10000
    best_model = 0
    for epoch in range(EPOCH_NUM):
        model.train(True)
        loss_batch_acc = 0
        loss_val_acc = 0
        for inputs, labels in tqdm(train_loader):
            op.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            outputs = outputs.softmax(dim=-1)
            labels = np.concatenate((1-labels,labels),axis=1)
            labels = torch.from_numpy(labels).to(torch.float32)
            loss=loss_f(outputs,labels)
            loss.backward()
            op.step()
            loss_batch_acc += loss.detach().numpy()*len(labels)
        loss_values.append(loss_batch_acc/train_dataset.__len__())
        model.eval()
        with torch.no_grad():
            for vinputs, vlabels in val_loader:
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs = model(vinputs)
                voutputs = voutputs.softmax(dim=-1)
                vlabels = np.concatenate((1-vlabels,vlabels),axis=1)
                vlabels = torch.from_numpy(vlabels).to(torch.float32)
                vloss = loss_f(voutputs, vlabels)
                loss_val_acc += vloss.detach().numpy()*len(vlabels)
        val_loss_values.append(loss_val_acc/val_dataset.__len__())
        if loss_val_acc/val_dataset.__len__() < best_loss:
            best_loss = loss_val_acc/val_dataset.__len__()
            best_model = model
    return loss_values, val_loss_values, best_model
RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same