Apple M1 silicon: TypeError: Cannot convert a MPS Tensor to float64 dtype

I am trying to train a SegFormer model on my local M1 Pro machine, following the Roboflow tutorial "How To Train SegFormer on a Custom Dataset for Computer Vision" (YouTube).

I searched for the error on Google, but I have no idea how to solve it. It seems to be an ongoing problem or bug.

Error message:

- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([3, 256, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate:
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                             | Params
0 | model | SegformerForSemanticSegmentation | 3.7 M 
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params
14.860    Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
Sanity Checking DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.57it/s]
invalid value encountered in divide
Traceback (most recent call last):
  File "/Users/tonggihkang/", line 296, in <module>
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 696, in fit
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1166, in _run
    results = self._run_stage()
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1252, in _run_stage
    return self._run_train()
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1274, in _run_train
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1343, in _run_sanity_check
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/loops/", line 207, in run
    output = self.on_run_end()
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/", line 183, in on_run_end
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/", line 293, in _evaluation_epoch_end
    self.trainer._call_lightning_module_hook(hook_name, output_or_outputs)
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1550, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/Users/tonggihkang/", line 191, in validation_epoch_end
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/core/", line 451, in log
    value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/utilities/", line 99, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/Users/tonggihkang/ML/lib/python3.10/site-packages/pytorch_lightning/core/", line 587, in __to_tensor
    else torch.tensor(value, device=self.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
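The last frames show where it fails. `self.log` passes any non-tensor value to `torch.tensor(value, device=self.device)`, and the MPS device cannot hold a float64 value. One possible workaround is sketched below. It assumes the logged metrics are NumPy scalars such as `np.float64`, which is plausible given the `load_metric` warning above, and converts them to plain Python floats before logging. `torch.tensor` builds a Python float with the default dtype (float32). The helper name `to_loggable_floats` is hypothetical and is not part of Lightning or 🤗 Evaluate.

```python
def to_loggable_floats(metrics):
    """Return a copy of `metrics` with every value as a plain Python float.

    Hypothetical helper (not part of Lightning). Assumes each value is a scalar:
    a Python number, a NumPy scalar such as np.float64, or a one-element tensor.
    """
    converted = {}
    for name, value in metrics.items():
        # NumPy scalars/arrays and torch tensors expose .item(), which returns a
        # Python scalar (it raises an error if the value has more than one element).
        if hasattr(value, "item"):
            value = value.item()
        # torch.tensor() turns a plain Python float into the default dtype
        # (float32), which the MPS device supports.
        converted[name] = float(value)
    return converted
```

In `validation_epoch_end` (line 191 of the script), this would be used as `self.log_dict(to_loggable_floats(metrics))`. Here `metrics` stands for whatever dictionary of scalar metrics the script logs. Per-class arrays cannot be logged as a single scalar and would need to be left out first.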