RuntimeError: Tensor for argument weight is on cpu but expected on mps

I’m getting this error:

2.4.0
mps
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Linear: 1-1                            [-1, 8]                   24
├─ReLU: 1-2                              [-1, 8]                   --
├─Linear: 1-3                            [-1, 1]                   9
==========================================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
Traceback (most recent call last):
  File "/Users/amitthakur/PycharmProjects/computer-vision/ann/my_dataset.py", line 56, in <module>
    main()
  File "/Users/amitthakur/PycharmProjects/computer-vision/ann/my_dataset.py", line 47, in main
    loss_value = loss_func(model(ix), iy)
                           ^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitthakur/anaconda3/envs/computer-vision/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Tensor for argument weight is on cpu but expected on mps
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torchsummary import summary
from torch.optim import SGD
import time

print(torch.__version__)

# Use Apple's Metal Performance Shaders backend when this build/machine
# supports it; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

class MyDataset(Dataset):
    """In-memory dataset of paired inputs and targets.

    Both ``x`` and ``y`` are converted once, at construction time, to float
    tensors on the module-level ``device``. Indexing therefore returns
    tensors that are already on that device.
    """

    def __init__(self, x, y):
        def _as_device_float(values):
            # Build the tensor, cast to float32, then move it to the target device.
            return torch.tensor(values).float().to(device)

        self.x = _as_device_float(x)
        self.y = _as_device_float(y)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        sample, target = self.x[index], self.y[index]
        return sample, target


def main():
    """Train a tiny 2 -> 8 -> 1 MLP on a toy dataset and print the elapsed time.

    The model, the dataset tensors and the summary probe input all live on
    the module-level ``device``. The per-step loss values are collected in
    ``loss_history``.
    """
    x = [[1, 2], [3, 4], [5, 6], [7, 8]]
    y = [[3], [7], [11], [15]]
    ds = MyDataset(x, y)
    dl = DataLoader(ds, batch_size=2, shuffle=True)

    model = nn.Sequential(
        nn.Linear(2, 8),
        nn.ReLU(),
        nn.Linear(8, 1),
    ).to(device)

    # Pass ``device`` explicitly. Without it, summary() uses its own default
    # device and appears to move the model's parameters there (CPU on this
    # machine). The forward pass on the MPS batches below then fails with
    # "Tensor for argument weight is on cpu but expected on mps".
    summary(model, torch.zeros(1, 2, device=device), device=device)

    loss_func = nn.MSELoss()
    opt = SGD(model.parameters(), lr=0.001)
    loss_history = []
    # perf_counter is monotonic and high-resolution, which suits interval timing.
    start = time.perf_counter()
    for _ in range(50):
        for ix, iy in dl:
            opt.zero_grad()
            loss_value = loss_func(model(ix), iy)
            loss_value.backward()
            opt.step()
            loss_history.append(loss_value.item())
    end = time.perf_counter()
    print(end - start)


if __name__ == '__main__':
    # Run the training demo only when executed as a script, not on import.
    main()

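Never mind, I found the root cause. Without an explicit `device` argument, `summary()` appears to move the model to its own default device (CPU here), while the training batches stay on MPS. Passing the device explicitly fixes it: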

# Tell summary() which device to use so it keeps the model on `device` (e.g. 'mps')
# instead of presumably relocating it to its own default device.
summary(model, torch.zeros(1, 2), device=device)