Can't load a model trained on Apple Silicon (MPS) onto the CPU inside Docker

I have trained a BERT model from Hugging Face Transformers, with an additional linear layer on top, on Apple Silicon (mps). The model is saved as follows:

    torch.save(self.state_dict(), model_file_path)
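When a checkpoint is saved straight from an MPS model, each tensor's storage location is recorded as `mps`. Copying every tensor to CPU before saving makes the checkpoint record `cpu` instead, so loading it elsewhere does not rely on the MPS backend. Below is a minimal sketch of that approach. It assumes a tiny `nn.Linear` stands in for the BERT model, and the helper name `save_cpu_state_dict` and the file name are hypothetical.

```python
import torch
from torch import nn


def save_cpu_state_dict(model: nn.Module, path: str) -> None:
    # Copy each tensor to CPU so the checkpoint records "cpu" as its location
    cpu_state = {name: t.detach().cpu() for name, t in model.state_dict().items()}
    torch.save(cpu_state, path)


# Hypothetical stand-in for the fine-tuned BERT + linear head
model = nn.Linear(4, 2)
save_cpu_state_dict(model, "model_cpu.pt")
```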

I have no problem loading and running this model on my Mac with either the 'cpu' or the 'mps' device. Here is the method I use to load it:

    def load_model_from_file(self, filename):
        state_dict = torch.load(filename, map_location=self.device)
        self.load_state_dict(state_dict)
        self.eval()
        self.to(self.device)
        return self
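Here is a minimal sketch of a more defensive variant of the loader. It always deserializes onto CPU, which every PyTorch build supports, and moves the weights to the target device only afterwards. Wrapping `self.device` in `torch.device(...)` also makes a malformed device string fail immediately with a readable error instead of deep inside `torch.load`. The `Classifier` class, its single linear layer, and the file name are hypothetical stand-ins for the real model.

```python
import torch
from torch import nn


class Classifier(nn.Module):  # hypothetical stand-in for the BERT-based model
    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.device = device

    def load_model_from_file(self, filename: str) -> "Classifier":
        # Deserialize onto CPU regardless of where the checkpoint was saved
        state_dict = torch.load(filename, map_location="cpu")
        self.load_state_dict(state_dict)
        # torch.device() raises right away if the device string is malformed
        self.to(torch.device(self.device))
        self.eval()
        return self


source = Classifier()
torch.save(source.state_dict(), "classifier.pt")
restored = Classifier(device="cpu").load_model_from_file("classifier.pt")
```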

But when I run the same code in a Docker container with the device set to 'cpu', I get the following error:

    2023-12-08 10:04:57 ERROR - run_s - Traceback (most recent call last):
      File "/s/scripts/run_s.py", line 188, in run
        trained_model = SModel.auto_load(**all_params)
      File "/s/s/modeling/base.py", line 56, in auto_load
        return base_model.load_model_from_file(train_model_filepath)
      File "/s/sa/modeling/BERT_class.py", line 205, in load_model_from_file
        state_dict = torch.load(filename, map_location=self.device)
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 1014, in load
        return _load(opened_zipfile,
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 1422, in _load
        result = unpickler.load()
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 1392, in persistent_load
        typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 1366, in load_tensor
        wrap_storage=restore_location(storage, location),
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 1296, in restore_location
        return default_restore_location(storage, map_location)
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 381, in default_restore_location
        result = fn(storage, location)
      File "/usr/local/lib/python3.9/site-packages/torch/serialization.py", line 304, in _hpu_deserialize
        assert hpu is not None, "HPU device module is not loaded"
    AssertionError: HPU device module is not loaded
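The traceback shows map_location being passed directly to default_restore_location. PyTorch's CPU handler only claims a storage when the location string is exactly 'cpu', so a value that differs even slightly may fall through to later handlers such as the HPU one. One unconfirmed explanation is that self.device inside the container is not literally 'cpu'. For example, it may have been read from an environment variable or config file that kept quotes, whitespace, or different capitalization. The minimal sketch below tests that assumption. The helper `normalize_device` and the set of accepted device types are hypothetical.

```python
# Hypothetical helper; the set of accepted device types is an assumption
ALLOWED_DEVICES = {"cpu", "mps", "cuda"}
# Straight and curly quote characters that may leak in from config files
_QUOTE_CHARS = "\"'\u2018\u2019\u201c\u201d"


def normalize_device(raw: str) -> str:
    """Strip whitespace and stray quotes from a device name, then validate it."""
    cleaned = raw.strip().strip(_QUOTE_CHARS).strip().lower()
    # Accept indexed forms such as "cuda:0" by checking only the device type
    device_type = cleaned.split(":", 1)[0]
    if device_type not in ALLOWED_DEVICES:
        raise ValueError(f"unsupported device string: {raw!r}")
    return cleaned


print(repr(normalize_device(" 'cpu'\n")))  # prints 'cpu'
```

If self.device is a string, calling torch.load(filename, map_location=normalize_device(self.device)) would then receive exactly 'cpu'. Otherwise it would fail with a message that shows the offending value.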

How can I fix this?