torch.sparse.max()?

I’m looking for a torch.sparse function that can take the max (and ideally the argmax) of a torch.sparse_coo_tensor along a specific sparse dimension. As far as I can tell from the docs, no such function exists, even though there is a torch.sparse.sum(). Is there an obvious workaround I’m missing, or is this too difficult to implement?

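import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)

coo_tensor = torch.sparse_coo_tensor(i, v, [2, 4])

# Works, but materializes the full dense tensor (implicit zeros included)
max_coo_tensor = coo_tensor.to_dense().max(0)[0]

# Error: max along a dimension is not supported for sparse COO tensors
max_coo_tensor = coo_tensor.max(0)[0]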
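One possible workaround, sketched under some assumptions: coalesce the tensor, then reduce its stored values per output slot with `Tensor.scatter_reduce` (available in recent PyTorch releases, 1.12 and later, with `reduce="amax"`/`"amin"` and the `include_self` flag). The helper name `sparse_max_along_dim` is hypothetical, not a torch.sparse API. It only handles 2-D COO tensors with floating-point values, and it looks only at explicitly stored values, so implicit zeros are ignored.

```python
import torch


def sparse_max_along_dim(coo: torch.Tensor, dim: int):
    """Max and argmax of a 2-D sparse COO tensor along `dim`, over stored values only.

    Hypothetical helper (not part of torch.sparse). Assumes floating-point values.
    Output slots with no stored entries get max = -inf and argmax = -1.
    """
    assert coo.is_sparse and coo.dim() == 2, "expects a 2-D sparse COO tensor"
    assert dim in (0, 1), "dim must be 0 or 1"
    coo = coo.coalesce()            # merge duplicate indices so each entry is unique
    idx = coo.indices()             # shape (2, nnz)
    vals = coo.values()             # shape (nnz,)
    keep = 1 - dim                  # the dimension that survives the reduction
    out_idx = idx[keep]             # output slot of each stored value
    n_out = coo.shape[keep]

    # Per-slot max; include_self=False keeps the -inf fill out of the reduction,
    # so only slots that receive no values stay at -inf.
    init = torch.full((n_out,), float("-inf"), dtype=vals.dtype, device=vals.device)
    max_vals = init.scatter_reduce(0, out_idx, vals, reduce="amax", include_self=False)

    # Argmax: among stored entries equal to their slot's max, keep the smallest
    # position along `dim`; all other entries get an out-of-range sentinel.
    pos = idx[dim]
    sentinel = coo.shape[dim]
    is_max = vals == max_vals[out_idx]
    cand = torch.where(is_max, pos, torch.full_like(pos, sentinel))
    arg = torch.full((n_out,), sentinel, dtype=torch.long, device=vals.device)
    arg = arg.scatter_reduce(0, out_idx, cand, reduce="amin")
    arg[arg == sentinel] = -1       # slots with no stored entries
    return max_vals, arg


# Usage with the tensor from the question
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
coo_tensor = torch.sparse_coo_tensor(i, v, [2, 4])

values, indices = sparse_max_along_dim(coo_tensor, 0)
print(values.tolist())   # [4.0, -inf, 5.0, -inf]
print(indices.tolist())  # [1, -1, 1, -1]
```

This differs from the dense version in two cases: a slice with no stored entries yields -inf/-1 instead of 0, and a slice whose stored values are all negative yields its largest stored value rather than an implicit 0. If dense semantics are wanted, one option is to count stored entries per slot with `torch.bincount(out_idx, minlength=n_out)` and clamp the max to at least 0 wherever that count is below `coo.shape[dim]`.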