Sparse matrix computation in PyTorch

Here’s a gist explaining this

I'm looking for a clear solution for achieving this in PyTorch.
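For reference, below is a minimal sketch of one common kind of sparse computation in PyTorch: a sparse-dense matrix product using a COO tensor. It assumes the computation in the gist has this form, which may not be the case. The matrices are made-up example data, not taken from the gist.

```python
import torch

# Hypothetical example data (not from the gist):
# a 3x3 sparse matrix in COO format with three non-zero entries.
indices = torch.tensor([[0, 1, 2],   # row indices
                        [2, 0, 1]])  # column indices
values = torch.tensor([3.0, 4.0, 5.0])
sparse = torch.sparse_coo_tensor(indices, values, size=(3, 3))

# Dense right-hand operand of shape (3, 2).
dense = torch.tensor([[1.0, 2.0],
                      [3.0, 4.0],
                      [5.0, 6.0]])

# Sparse @ dense -> dense result. Only the stored non-zeros of `sparse` are used.
result = torch.sparse.mm(sparse, dense)

# Sanity check against the equivalent fully dense computation.
assert torch.allclose(result, sparse.to_dense() @ dense)

# result holds the values [[15, 18], [4, 8], [15, 20]]
print(result)
```

COO format stores only the coordinates and values of the non-zero entries. This is what makes it worthwhile when most entries are zero.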