Code didn't speed up as expected when using `mps`

I’m really excited to try out the latest PyTorch build (1.12.0.dev20220518) for M1 GPU support, but on my device (M1 Max, 64GB, 16-inch MBP) the training time per epoch is ~9s on cpu, and after switching to mps it gets significantly slower, at ~17s. Is that something we should expect, or did I just mess something up?
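When comparing per-epoch times between devices, keep in mind that GPU work is queued asynchronously, so a fair timing has to wait for the device to finish. The following is a minimal sketch with a hypothetical toy MLP and random data, not the script from this thread. Calling `.item()` on the loss copies it to the host, which forces the queued work to complete before the clock is read.

```python
import time

import torch
import torch.nn as nn


def seconds_per_step(device: str, steps: int = 20, batch: int = 256, width: int = 512) -> float:
    """Average wall-clock seconds per training step of a toy MLP (hypothetical workload)."""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, 10)).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    # Build float32 data on the CPU, then move it (float64 is not supported on mps).
    x = torch.randn(batch, width).to(device)
    y = torch.randint(0, 10, (batch,)).to(device)

    def step() -> float:
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
        # .item() copies the loss to the host, so it waits for all queued work.
        return loss.item()

    step()  # warm-up: the first step includes one-time setup costs
    start = time.perf_counter()
    for _ in range(steps):
        step()
    return (time.perf_counter() - start) / steps


if __name__ == "__main__":
    devices = ["cpu"]
    if torch.backends.mps.is_available():
        devices.append("mps")
    for dev in devices:
        print(f"{dev}: {seconds_per_step(dev) * 1000:.2f} ms/step")
```

Varying `batch` and `width` shows how the comparison depends on the size of the workload.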

BTW, there was also a failed assertion that may be related to mps. The log is as follows:

/AppleInternal/Library/BuildRoots/560148d7-a559-11ec-8c96-4add460b61a6/Library/Caches/ failed assertion `unsupported datatype for constant'

/opt/homebrew/Cellar/python@3.10/3.10.4/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/ UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

What data types were the tensors you were using? Could you show us the script? I have a good idea of which data types MPS and MPSGraph support. Also, I think I have seen some peculiarities in MPSGraph’s internal MLIR compiler that might cause a crash.

@albanD could you look into the memory management issue in the log from the comment above? I have previously seen apps crash when a DispatchSemaphore isn’t balanced before it is released. That may be an iOS-specific crash, but balancing semaphore objects should be standard practice when using Foundation.

the training time per epoch on cpu is ~9s, but after switching to mps, the performance drops significantly to ~17s.

If you are using all CPU cores on the M1 Max in cpu mode, you have 2.0 TFLOPS of processing power (256 GFLOPS per core) for matrix multiplication. On the GPU, you have a maximum of 10.4 TFLOPS, although only about 80% of that is usable. So if a particular tensor shape runs better on the CPU than on the GPU, the slowdown you experienced is plausible.
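To put those peak numbers in context, multiplying an (M×K) matrix by a (K×N) matrix costs about 2·M·K·N floating-point operations (one multiply and one add per inner-product term). The sketch below turns that into an ideal, best-case runtime using the figures quoted in the comment above: 2.0 TFLOPS on the CPU and 10.4 TFLOPS on the GPU, assuming about 80% of the GPU peak is achieved. The shapes are hypothetical examples. For small shapes the ideal GPU time is only a few microseconds, so any fixed per-operation overhead can easily dominate.

```python
CPU_TFLOPS = 2.0          # all CPU cores, figure quoted above
GPU_TFLOPS = 10.4 * 0.8   # GPU peak, assuming ~80% of it is achieved


def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for (m x k) @ (k x n): one multiply and one add per inner-product term."""
    return 2 * m * k * n


def ideal_seconds(flops: int, tflops: float) -> float:
    """Best-case runtime at a sustained `tflops`; ignores memory traffic and launch overhead."""
    return flops / (tflops * 1e12)


if __name__ == "__main__":
    for m, k, n in [(64, 64, 64), (1024, 1024, 1024), (4096, 4096, 4096)]:
        f = matmul_flops(m, k, n)
        cpu_us = ideal_seconds(f, CPU_TFLOPS) * 1e6
        gpu_us = ideal_seconds(f, GPU_TFLOPS) * 1e6
        print(f"{m}x{k}x{n}: {f:.3e} FLOPs, CPU {cpu_us:.2f} us, GPU {gpu_us:.2f} us")
```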

Thanks for the report.

  • For the datatype crash: make sure you’re not using float64 values near the place where the crash happens, as float64 is not supported for tensors on mps.
  • For the slowdown: that is definitely not expected (though it can happen depending on the workload, since GPUs need relatively large tasks to show a speedup). If you have a simple repro for us to run locally, that would be very helpful!
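One way to follow the float64 advice is to convert at the point where data is moved to the device. This is a minimal sketch; `to_device_safe` is a hypothetical helper, not a PyTorch API. It casts float64 tensors to float32 on the CPU (losing precision) before the transfer, and walks plain dicts, lists, and tuples of tensors.

```python
from typing import Any

import torch


def to_device_safe(obj: Any, device: str) -> Any:
    """Recursively move tensors to `device`, casting float64 to float32 first."""
    if isinstance(obj, torch.Tensor):
        if obj.dtype == torch.float64:
            obj = obj.to(torch.float32)  # cast on the CPU before the transfer
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: to_device_safe(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_device_safe(value, device) for value in obj]
    if isinstance(obj, tuple):
        # Namedtuples come back as plain tuples in this sketch.
        return tuple(to_device_safe(value, device) for value in obj)
    return obj  # non-tensor leaves (ints, strings, ...) are returned unchanged


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    batch = {"x": torch.zeros(2, 3, dtype=torch.float64), "y": torch.tensor([0, 1])}
    batch = to_device_safe(batch, device)
    print(batch["x"].dtype, batch["x"].device)
```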

For a more extensive list of which data types do and don’t run:

  • Avoid Float64 on all Apple devices. Even if the hardware supports Double physically (AMD or Intel), the Metal API doesn’t let you access it.
  • Avoid BFloat16. That is natively supported by the latest Nvidia GPUs, but not supported in Metal. Also don’t try to use TF18/TF32 or Int4.
  • All standard integer types (UInt8, UInt16, UInt32, UInt64) and their signed counterparts work natively on Apple devices. 8-bit integers are a partial exception: they are cast to 16-bit integers before being stored in registers, but that won’t hurt performance. And yes, the Apple GPU runs 64-bit integers, just not 64-bit floats. Metal also lets you use 64-bit integers in shaders on AMD and Intel GPUs, but the arithmetic there may be emulated (slow). I think that’s where I hit the MPSGraph crash before: trying to run an operation on UInt64 on my Intel Mac mini.
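To check an existing model against the list above, a small audit over its parameters and buffers can flag the dtypes to avoid. This is a sketch that assumes float64 and bfloat16 are the only problem types in your model. `find_unsupported` and `AVOID_ON_MPS` are hypothetical names, and calling `model.float()` is one way to convert the flagged floating-point tensors.

```python
import torch
import torch.nn as nn

# dtypes the list above advises avoiding on Apple GPUs
AVOID_ON_MPS = {torch.float64, torch.bfloat16}


def find_unsupported(model: nn.Module) -> list[tuple[str, torch.dtype]]:
    """Return (name, dtype) for every parameter or buffer whose dtype is in AVOID_ON_MPS."""
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    return [(name, t.dtype) for name, t in tensors if t.dtype in AVOID_ON_MPS]


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).double()
    print(find_unsupported(model))          # flags the float64 weights and BN statistics
    print(find_unsupported(model.float()))  # empty after converting to float32
```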

For anything else regarding compatibility of data types, the Metal feature set tables might be useful:


I now remember the MPSGraph crash more precisely. I was experimenting with transcendental functions while prototyping an S4TF Metal backend, and had a configuration where I passed a UInt64 MPSGraphTensor into a function that normally runs in floating point in hardware, so MPSGraph has to convert the integer to a float before actually running the operation. It worked for everything up to 32-bit integers, but a special configuration using 64-bit integers broke the MLIR compiler. The error message was something like “pow_u64 not found”. You may need to watch out for a bug like this as well as for Float64 crashes.
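In PyTorch terms, a defensive workaround for this class of bug is to cast integer inputs to float32 yourself before a transcendental op, instead of relying on the backend to do the conversion. This sketch assumes float32 precision is acceptable. float32 represents integers exactly only up to 2^24, so larger int64 values get rounded. `safe_pow` is a hypothetical helper.

```python
import torch


def safe_pow(x: torch.Tensor, exponent: float) -> torch.Tensor:
    """Compute x ** exponent, casting integer inputs to float32 first."""
    if not x.is_floating_point():
        # Exact only for |x| <= 2**24; larger integers are rounded.
        x = x.to(torch.float32)
    return torch.pow(x, exponent)


if __name__ == "__main__":
    print(safe_pow(torch.arange(5), 2.0))
```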

Here is all the code I use to test; just download it and run it. I think I am indeed using the Float64 datatype, which is the default, I think? I’ll try another datatype later.
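On the default-dtype question: PyTorch’s own default floating-point dtype is float32, but NumPy’s is float64, and `torch.from_numpy` keeps the array’s dtype. Float64 therefore often sneaks in when data is loaded through NumPy. Here is a quick check, assuming the data passes through NumPy somewhere in the script:

```python
import numpy as np
import torch

print(torch.get_default_dtype())       # torch.float32
print(torch.tensor([1.0, 2.0]).dtype)  # torch.float32

arr = np.random.rand(4, 3)             # NumPy defaults to float64
t = torch.from_numpy(arr)
print(t.dtype)                         # torch.float64

# Convert before moving to mps (t.float() would also work).
t32 = torch.from_numpy(arr.astype(np.float32))
print(t32.dtype)                       # torch.float32
```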

Sorry, I can only put up to 2 links in a post, so here is the other half of the code: