Mac M1 MPS error (mps.scatter_nd) when iterating across architectures

I’m using nanoGPT, but this post is more about MPS than about nanoGPT.

Hi all,

I’m on an Apple M1 and used nanoGPT’s train.py to build a train_gpt(…) function so that I can sweep across architectures (varying the number of layers, the number of heads, and the embedding size).

When I use only the CPU, I get no errors. But when I use MPS, I get this error:

loc("edb"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/97f6331a-ba75-11ed-a4bc-863efbbaf80d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":49:0)): error: 'mps.scatter_nd' op invalid input tensor shape: updates tensor shape and data tensor shape must match along inner dimensions

/AppleInternal/Library/BuildRoots/97f6331a-ba75-11ed-a4bc-863efbbaf80d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1710: failed assertion `Error: MLIR pass manager failed'

If I use the mps device and keep the architecture fixed (e.g. {'n_layer': 4, 'n_head' : 4, 'n_embd' : 128}), MPS works without error. However, when I loop across varying architectures, I get the error above.

The issue is not incompatible n_layer, n_head, and n_embd dimensions, and it is not a memory issue. I’ve tested that the architectures on which the error occurs work fine when run alone (rather than in a loop after a different architecture). The general rules I use for generating architectures are: 1. n_head is a multiple of n_layer, and 2. n_embd is divisible by n_head.
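For reference, here is a minimal sketch of how candidate architectures could be enumerated under those two rules. The function name and the option lists are hypothetical illustrations, not the actual sweep code:

```python
from itertools import product


def candidate_architectures(layer_options, head_options, embd_options):
    """Yield configs that satisfy the two rules above (hypothetical helper):
    1. n_head is a multiple of n_layer
    2. n_embd is divisible by n_head
    """
    for n_layer, n_head, n_embd in product(layer_options, head_options, embd_options):
        if n_head % n_layer == 0 and n_embd % n_head == 0:
            yield {'n_layer': n_layer, 'n_head': n_head, 'n_embd': n_embd}


# Example option lists (made up for illustration).
archs = list(candidate_architectures([2, 3, 4], [4, 6, 8], [96, 128]))
for arch in archs:
    print(arch)
```

With these option lists, the fixed architecture from above ({'n_layer': 4, 'n_head': 4, 'n_embd': 128}) is one of the generated configs, while combinations such as n_layer=3 with n_head=4 are filtered out.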

So it smells like the GPU cache is not being cleared after each train_gpt(…) iteration, and some operations are still expecting stale dimensions from the previous architecture.

Is there a way to clear the GPU cache in the MPS framework? I’ve tried calling mps.empty_cache() at the end of each iteration (by iteration, I mean one call to train_gpt(…) for a specific architecture, not a step of the training loop), but that doesn’t help.
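The per-iteration cleanup described above might look roughly like the sketch below. It assumes PyTorch 2.0 or later (where torch.mps.synchronize() and torch.mps.empty_cache() exist). The names release_mps_memory, run_sweep and train_fn are hypothetical, and train_fn stands in for train_gpt(…), whose real signature is not shown here. Per the PyTorch docs, torch.mps.empty_cache() releases unoccupied memory held by the caching allocator, so it is not documented to reset any other per-process MPS state:

```python
import gc

try:
    import torch
except ImportError:  # lets the sketch run (cleanup becomes a no-op) without PyTorch
    torch = None


def release_mps_memory():
    """Best-effort cleanup between sweep iterations (hypothetical helper)."""
    # Collect Python garbage first so tensors from the previous model/optimizer
    # are actually unreferenced; the allocator can only free unoccupied blocks.
    gc.collect()
    if torch is not None and torch.backends.mps.is_available():
        torch.mps.synchronize()   # wait for queued MPS work to finish
        torch.mps.empty_cache()   # release unoccupied cached allocator memory


def run_sweep(train_fn, architectures):
    """Call train_fn once per architecture, cleaning up after each call.

    train_fn is a hypothetical stand-in for train_gpt(...); here it takes one
    config dict and returns something small (e.g. a final loss), not the model,
    so the model can be garbage-collected before the next iteration.
    """
    results = []
    for arch in architectures:
        results.append(train_fn(arch))
        release_mps_memory()
    return results
```

A quick dry run with a dummy training function: run_sweep(lambda arch: arch['n_layer'] * arch['n_head'], [{'n_layer': 4, 'n_head': 4, 'n_embd': 128}]) returns [16].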

Thanks in advance!
