AD Sparse KMeans with CSR format?


We are testing automatic differentiation (AD) with sparse representations in PyTorch, using a simple k-means example. We currently use the COO format, because vjp/vhp throw exceptions with the CSR format. We would appreciate help with a couple of questions.
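For reference, here is a minimal sketch of the kind of setup described above. It is not our exact code. The toy data and the fixed cluster assignments (`assign`) are hypothetical, and it assumes `torch.autograd.functional.vjp`/`vhp` and `torch.sparse.mm`. The squared distance ||x_i - c||^2 is expanded as ||x_i||^2 - 2 x_i·c + ||c||^2, so the sparse data matrix only enters the loss through a sparse-dense matrix product. It uses the COO layout, which is the workaround that currently runs for us.

```python
import torch
from torch.autograd.functional import vjp, vhp

torch.manual_seed(0)

# Hypothetical toy data: 6 points in 4 dimensions, mostly zeros.
dense_X = torch.tensor(
    [[1., 0., 0., 2.],
     [0., 3., 0., 0.],
     [1., 0., 0., 1.],
     [0., 2., 1., 0.],
     [0., 0., 4., 0.],
     [2., 0., 0., 0.]],
    dtype=torch.float64,
)
X = dense_X.to_sparse().coalesce()  # COO layout (the current workaround)
k = 2
assign = torch.tensor([0, 1, 0, 1, 1, 0])  # hypothetical fixed cluster assignments

# sum_i ||x_i||^2 does not depend on the centroids, so compute it once.
sq_norm_X = X.values().pow(2).sum()

def kmeans_loss(C):
    """Sum of squared distances sum_i ||x_i - C[assign[i]]||^2.

    Expanded as ||x_i||^2 - 2 x_i . c + ||c||^2 so that the sparse
    data only enters through a sparse-dense matrix product.
    """
    cross = torch.sparse.mm(X, C.t())  # dense (n, k): cross[i, j] = x_i . c_j
    cross_assigned = cross.gather(1, assign.unsqueeze(1)).sum()
    return sq_norm_X - 2.0 * cross_assigned + (C[assign] ** 2).sum()

C0 = torch.randn(k, dense_X.shape[1], dtype=torch.float64)

# Vector-Jacobian product; v may be omitted because the loss is a scalar.
loss, grad_C = vjp(kmeans_loss, C0)

# Hessian-vector product with an all-ones direction.
v = torch.ones_like(C0)
_, hvp_C = vhp(kmeans_loss, C0, v)

print(loss.item())
print(grad_C)
print(hvp_C)
```

To reproduce the CSR case, one would replace the `to_sparse().coalesce()` line with `dense_X.to_sparse_csr()`. Whether that path is differentiable is exactly the question below.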

Does AD support the CSR format? If so, are we using the wrong matrix multiplication function (or the wrong version of another sparse operation)?

Edit: the forum allows only two links per post; the second question has since been resolved.