Should I flatten the last 3 dimensions myself (4D → 2D) so the parameter can be optimized with Muon, or can the PyTorch implementation do that for me?
See the documentation:
Note that Muon is an optimizer for 2D parameters of neural network hidden layers. Other parameters, such as bias, and embedding, should be optimized by a standard method such as AdamW.
I believe you will run into this error message for non-2D parameters.
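For what it's worth, here is a minimal sketch of how the parameters could be split following that note. `split_params_for_muon` is a hypothetical helper of mine, not a PyTorch API, and the learning rates are placeholders. It sends only 2D non-embedding weights to Muon and everything else (biases, embeddings, 4D conv weights) to AdamW. The last line shows what a manual 4D → 2D flatten would look like, as a view only; whether you want to train conv filters that way is your call.

```python
import torch
from torch import nn


def split_params_for_muon(model: nn.Module):
    """Hypothetical helper: return (muon_params, other_params).

    Only 2D weights that do not belong to an nn.Embedding go to Muon;
    biases, embeddings, and non-2D tensors go to the other group.
    """
    embedding_ids = {
        id(p)
        for m in model.modules()
        if isinstance(m, nn.Embedding)
        for p in m.parameters()
    }
    muon_params, other_params = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.ndim == 2 and id(p) not in embedding_ids:
            muon_params.append(p)
        else:
            other_params.append(p)
    return muon_params, other_params


# Toy container just to have one parameter of each kind (never run forward).
model = nn.ModuleList([nn.Conv2d(3, 8, 3), nn.Linear(8, 4), nn.Embedding(10, 8)])
muon_params, other_params = split_params_for_muon(model)

adamw = torch.optim.AdamW(other_params, lr=1e-3)
# torch.optim.Muon only exists in recent PyTorch releases, so guard the lookup.
muon = torch.optim.Muon(muon_params, lr=1e-3) if hasattr(torch.optim, "Muon") else None

# Manual 4D -> 2D flatten of the conv filter: (out, in, kh, kw) -> (out, in*kh*kw).
flat_shape = tuple(model[0].weight.flatten(1).shape)
```

In the training loop you would then call `step()` and `zero_grad()` on both optimizers.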