Would it be possible to integrate a distributed matrix-matrix multiplication algorithm like COSMA into PyTorch? COSMA is a communication-optimal, GPU-accelerated algorithm for matrix multiplication (matmul) that is already used in some HPC applications with great performance results!
I am one of the main developers of COSMA and would be happy to integrate it into PyTorch.
Let me know your thoughts on that!