Min Operation On a Sparse Tensor / Variable Sizes Across Second Dimension

Here’s the situation. I have a very large 2d tensor which is extremely sparse (like, <0.1% density); everything else is filled with NaN. I want the minimum value of each row, i.e. a reduction along the second dimension: essentially torch.min(t, dim=1). But doing this over the entire tensor seems extremely wasteful when each row only has 2-4 real values out of tens of thousands.
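For reference, the dense version described above might look like the sketch below. One catch: as far as I know, torch.min and torch.amin propagate NaN, so a plain torch.min(t, dim=1) on a NaN-padded tensor would return NaN for any row containing a NaN. Swapping NaN for +inf first avoids that. `dense_row_min` is a hypothetical helper name, and a row with no real entries comes out as inf.

```python
import torch

def dense_row_min(b: torch.Tensor) -> torch.Tensor:
    """Hypothetical dense baseline: per-row min of a NaN-padded 2d tensor."""
    # Replace the NaN placeholders with +inf so they never win the min
    filled = b.masked_fill(torch.isnan(b), float("inf"))
    # Reduce along dim=1: one minimum per row, touching every element
    return filled.amin(dim=1)

nan = float("nan")
b = torch.tensor([[nan, 2.0, nan, 1.0],
                  [5.0, nan, nan, 3.0]])
print(dense_row_min(b))  # tensor([1., 3.])
```

This is the approach that scales with the full dense shape, which is exactly the cost the question is trying to avoid.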

I already need to pull the real (non-NaN) elements out for separate processing, which flattens them into a 1d tensor, something like a = b[~torch.isnan(b)]. I can also count how many real entries there are in each row, which gives me another tensor with the number of elements I need to take the min of for each row: [3, 2, 4, 2, 4, 3, 2, …]. The expected value for all of these is roughly similar. I can convert these counts to index ranges by taking the cumulative sum with a leading 0: [0, 3, 5, 9, 11, …]. So the first row of the original tensor is now stored in elements 0:3 of the new tensor, the second row in 3:5, the third in 5:9, etc.
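The flattening and cumulative-sum steps can be sketched like this, assuming `b` is the NaN-padded 2d tensor. Boolean-mask indexing walks the tensor in row-major order, so each row's surviving entries end up contiguous in `values`, which is what makes the offsets valid. `flatten_valid` is a hypothetical name.

```python
import torch

def flatten_valid(b: torch.Tensor):
    """Hypothetical helper: real values, per-row counts and range offsets."""
    mask = ~torch.isnan(b)
    # Row-major flattening: row 0's values first, then row 1's, and so on
    values = b[mask]
    # Number of real entries in each row
    counts = mask.sum(dim=1)
    # Offsets with a leading 0, so row i lives in values[offsets[i]:offsets[i+1]]
    offsets = torch.zeros(counts.numel() + 1, dtype=torch.long, device=b.device)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return values, counts, offsets

nan = float("nan")
b = torch.tensor([[1.0, nan, 2.0, 0.5],
                  [nan, 3.0, nan, 4.0],
                  [nan, nan, 7.0, nan]])
values, counts, offsets = flatten_valid(b)
print(counts.tolist())   # [3, 2, 1]
print(offsets.tolist())  # [0, 3, 5, 6]
```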

After doing all of this, what I’m essentially looking for is a parallel, segmented version of torch.min, where I pass in the range boundaries as a tensor and each range is reduced independently. For example, min_fancy(a, [0, 3, 5, 9, 11…]) would find the min of elements 0:3 and store it in row 0, the min of 3:5 and store it in row 1, etc. I imagine this will involve writing a custom C++ extension, but I figured I should ask whether there is a built-in method I’m missing, or whether anyone has a better idea of how to do this, since I’m not very familiar with writing PyTorch extensions.
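One built-in route that may avoid a C++ extension is `Tensor.scatter_reduce` with `reduce="amin"` (available since PyTorch 1.12, where it was marked beta), combined with `torch.repeat_interleave` to turn the offsets into a segment id for every element. Below is a minimal sketch of the min_fancy idea along those lines; `segment_min` is a hypothetical name, and it assumes `offsets` is a 1d `long` tensor with a leading 0.

```python
import torch

def segment_min(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: min of values[offsets[i]:offsets[i+1]] for each i."""
    counts = torch.diff(offsets)  # length of each segment
    n_segments = counts.numel()
    # Segment id per element, e.g. counts [3, 2] -> [0, 0, 0, 1, 1]
    seg_ids = torch.repeat_interleave(
        torch.arange(n_segments, device=values.device), counts
    )
    # Start every output slot at +inf so empty segments come out as inf
    out = torch.full((n_segments,), float("inf"),
                     dtype=values.dtype, device=values.device)
    # For each segment id, keep the minimum of the scattered values
    return out.scatter_reduce(0, seg_ids, values, reduce="amin", include_self=True)

a = torch.tensor([4.0, 1.0, 7.0, 3.0, 2.0, 9.0, 8.0, 5.0, 6.0])
offsets = torch.tensor([0, 3, 5, 9])
print(segment_min(a, offsets))  # tensor([1., 2., 5.])
```

The work here is proportional to the number of real entries rather than the full dense shape. If you start from the 2d tensor directly, the row indices from `(~torch.isnan(b)).nonzero(as_tuple=True)[0]` can serve as `seg_ids` without building offsets at all.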