Currently, matrix multiplication seems to require matching dtypes of its inputs:
In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double
Is this a deliberate decision (for example, to limit the number of kernels), or an oversight that could be fixed in the future?
The context for the question is the array API compatibility work in SciPy: scipy/scipy Pull Request #22008, "ENH: signal.vectorstrength: add array API standard support". A workaround is not particularly difficult, but it would be great to have a canonical answer on whether it needs to be permanent or temporary, hence this post :-).
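For concreteness, here is a minimal sketch of one possible workaround, assuming explicit promotion to a common dtype is acceptable. The helper name `matmul_promoted` is hypothetical and is not necessarily what the SciPy PR does; it only relies on `torch.result_type` and `Tensor.to`:

```python
import torch


def matmul_promoted(a, b):
    """Hypothetical helper: cast both operands to a common dtype, then multiply."""
    # torch.result_type applies PyTorch's type promotion rules to the two operands.
    dtype = torch.result_type(a, b)
    # Tensor.to returns the tensor itself when it already has the target dtype.
    return a.to(dtype) @ b.to(dtype)


x = torch.ones(3, dtype=torch.float32)
y = torch.ones(3, dtype=torch.float64)
z = matmul_promoted(x, y)
print(z.dtype)  # torch.float64
```

The cost is an extra copy of the lower-precision operand, which is why it matters whether this needs to stay permanently.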
Cheers,
Evgeni