map_location not working with MPS

Given device = torch.device("mps"), I can't use functions that take a map_location parameter, although the same calls work fine with "cpu" or "cuda:0" devices.

model = torch.jit.load(buffer, map_location=device) will give the following error:

    model = torch.jit.load(buffer, map_location=device)
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/jit/_serialization.py", line 164, in load
    cpp_module = torch._C.import_ir_module_from_buffer(
RuntimeError: supported devices include CPU and CUDA, however got MPS
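Until torch.jit.load accepts MPS directly, one possible workaround is to let it deserialize onto the CPU, which the error message says is supported, and move the module afterwards. This is a sketch under assumptions: it needs PyTorch 1.12 or later (for torch.backends.mps), load_scripted_model is a hypothetical helper rather than a PyTorch API, and a tiny scripted torch.nn.Linear stands in for the real buffer.

```python
import io

import torch


def load_scripted_model(buffer: io.BytesIO, device: torch.device) -> torch.jit.ScriptModule:
    # Hypothetical helper: deserialize onto the CPU (supported by torch.jit.load),
    # then move parameters and buffers to the requested device.
    model = torch.jit.load(buffer, map_location="cpu")
    return model.to(device)


# Use MPS when this build and machine support it, otherwise stay on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Stand-in for the real model: script a tiny module and save it to memory.
scripted = torch.jit.script(torch.nn.Linear(4, 2))
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

model = load_scripted_model(buffer, device)
print(next(model.parameters()).device)
```

Moving after loading costs one extra host-to-device copy, which is usually negligible compared with loading the file itself.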

Similarly, model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl', map_location=device) gives the following error:

    model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl', map_location=device)
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/package/package_importer.py", line 256, in load_pickle
    result = unpickler.load()
  File "/Users/divide/TorchStudio/python/lib/python3.9/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/Users/divide/TorchStudio/python/lib/python3.9/pickle.py", line 1253, in load_binpersid
    self.append(self.persistent_load(pid))
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/package/package_importer.py", line 213, in persistent_load
    load_tensor(
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/package/package_importer.py", line 201, in load_tensor
    loaded_storages[key] = restore_location(storage, location)
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/serialization.py", line 973, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/serialization.py", line 178, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
RuntimeError: don't know how to restore data location of torch.storage._TypedStorage (tagged with mps)
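The same idea may apply to torch.package: map storages to "cpu", which default_restore_location knows how to handle, then call .to(device) on the unpickled module. This is an untested sketch assuming PyTorch 1.12 or later; the package name "model" and resource "model.pkl" mirror the call above, while load_packaged_model and the toy nn.Linear payload are hypothetical stand-ins for the real package.

```python
import io

import torch
from torch.package import PackageExporter, PackageImporter


def load_packaged_model(buffer: io.BytesIO, device: torch.device) -> torch.nn.Module:
    # Hypothetical helper: restore storages on the CPU, then move the module.
    importer = PackageImporter(buffer)
    model = importer.load_pickle("model", "model.pkl", map_location="cpu")
    return model.to(device)


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Stand-in for the real package: pickle a tiny module into an in-memory buffer.
buffer = io.BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.extern("torch.**")  # resolve torch from the installed library
    exporter.save_pickle("model", "model.pkl", torch.nn.Linear(4, 2))
buffer.seek(0)

model = load_packaged_model(buffer, device)
print(next(model.parameters()).device)
```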

This is supposedly fixed in the GitHub issue "torch.load() fails on MPS backend ("don't know how to restore data location")" (pytorch/pytorch #79384), but I still get the same error.

@Robin_Lobel The serialization code has been added back and will hopefully work with your example. Can you please check the latest nightly?
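To check whether a given build (for example the latest nightly) can restore storages directly onto MPS, a small round trip through torch.load may help, since that is the code path discussed in the linked issue. This is a sketch under assumptions: mps_map_location_works is a hypothetical helper, and it only exercises torch.load, so torch.jit.load and PackageImporter.load_pickle still need to be retried separately.

```python
import io

import torch


def mps_map_location_works() -> bool:
    """Return True if torch.load can map a saved tensor straight onto MPS."""
    if not torch.backends.mps.is_available():
        # No usable MPS device here, so there is nothing to check.
        return False
    buffer = io.BytesIO()
    torch.save(torch.arange(4.0), buffer)
    buffer.seek(0)
    try:
        tensor = torch.load(buffer, map_location=torch.device("mps"))
    except RuntimeError:
        # Builds without MPS restore support raise RuntimeError here.
        return False
    return tensor.device.type == "mps"


result = mps_map_location_works()
print(torch.__version__, result)
```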