PyTorch scatter max for sparse tensors?

I have the following PyTorch code:

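```python
import torch
from torch_scatter import scatter_max

# (num_lines, img_size, img_size) sparse tensor holding each line's pixel values
value_tensor = torch.sparse_coo_tensor(indices=query_indices.t(), values=values, size=(num_lines, img_size, img_size)).to(device=device)
value_tensor = value_tensor.to_dense()  # slow: materializes every line in full
# pixel index of each element of the flattened tensor: 0..img_size**2 - 1, once per line
indices = torch.arange(0, img_size * img_size).repeat(len(lines)).to(device=device)
line_tensor_flat = value_tensor.flatten()
img, _ = scatter_max(line_tensor_flat, indices, dim=0)  # per-pixel max across lines
img = torch.reshape(img, (img_size, img_size))
```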
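For reference, the flatten / `arange().repeat()` / `scatter_max` combination in the code above is a per-pixel maximum over the line dimension. The sketch below checks that on a small random dense tensor. It is an illustration only: the sizes are made up, and `Tensor.scatter_reduce_` with `reduce="amax"` (PyTorch 1.12 or later) stands in for `torch_scatter.scatter_max` so the example needs nothing beyond PyTorch.

```python
import torch

# Hypothetical sizes, for illustration only.
num_lines, img_size = 3, 4
torch.manual_seed(0)
dense = torch.randn(num_lines, img_size, img_size)

# Same index trick as in the question: element (l, i, j) of the flattened
# tensor is assigned to pixel i * img_size + j.
indices = torch.arange(img_size * img_size).repeat(num_lines)
flat = dense.flatten()

# Scatter-max in plain PyTorch (stand-in for torch_scatter.scatter_max).
# include_self=False ignores the initial zeros; every pixel receives values.
img_scatter = torch.zeros(img_size * img_size).scatter_reduce_(
    0, indices, flat, reduce="amax", include_self=False
).reshape(img_size, img_size)

# The same result as a direct reduction over the line dimension.
img_amax = dense.amax(dim=0)
print(torch.equal(img_scatter, img_amax))  # True
```

So on the dense tensor the scatter is equivalent to `value_tensor.amax(dim=0)`; the expensive part is the densification itself.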

Note the line `value_tensor = value_tensor.to_dense()`: unsurprisingly, it is slow.

However, I cannot figure out how to obtain the same result with a sparse tensor. The code relies on `reshape`, which is not available for sparse tensors. I'm using `scatter_max`, but I'm open to anything that works.
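One way to avoid `to_dense()` is to skip the reshape of the sparse tensor entirely and scatter its stored values straight into pixel slots. The sketch below assumes float values and a 3-D COO tensor laid out as `(line, row, col)`, as in the question; the helper name `sparse_line_max` is hypothetical. It coalesces first (so duplicate coordinates are summed, like `to_dense()` does), computes a flat pixel index `row * W + col` for each stored value, and takes a scatter `amax`. To reproduce the dense result exactly, a pixel that some line does not store gets that line's implicit zero in the max; a per-pixel `bincount` handles this. `Tensor.scatter_reduce_` (PyTorch 1.12 or later) is used here; `torch_scatter.scatter_max(vals, flat_pixel, dim=0, dim_size=H * W)` would be an alternative, but it fills empty pixels with 0 and ignores the implicit zeros.

```python
import torch


def sparse_line_max(value_tensor: torch.Tensor) -> torch.Tensor:
    """Per-pixel max over dim 0 of a sparse COO tensor of shape (num_lines, H, W),
    matching value_tensor.to_dense().amax(dim=0) without densifying.
    Assumes a floating-point dtype (uses -inf as the scatter identity)."""
    num_lines, h, w = value_tensor.shape
    v = value_tensor.coalesce()   # sums duplicate coordinates, as to_dense() does
    idx = v.indices()             # shape (3, nnz): rows are (line, row, col)
    vals = v.values()
    flat_pixel = idx[1] * w + idx[2]  # flat pixel index of each stored value

    # Max over the values actually stored at each pixel (-inf where none).
    out = torch.full((h * w,), float("-inf"), dtype=vals.dtype, device=vals.device)
    out.scatter_reduce_(0, flat_pixel, vals, reduce="amax", include_self=True)

    # After coalescing, each line stores at most one entry per pixel. If fewer
    # than num_lines lines store the pixel, an implicit 0 joins the max.
    counts = torch.bincount(flat_pixel, minlength=h * w)
    out = torch.where(counts < num_lines, out.clamp(min=0), out)
    return out.reshape(h, w)


# Tiny example: 2 lines, 2x2 image; query_indices rows are (line, row, col).
query_indices = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 0], [1, 1, 1]])
values = torch.tensor([-1.0, -3.0, 2.0, -5.0])
value_tensor = torch.sparse_coo_tensor(query_indices.t(), values, size=(2, 2, 2))
img = sparse_line_max(value_tensor)  # expected [[0., -1.], [2., 0.]]
print(img)
```

With the question's variables this would become `img = sparse_line_max(value_tensor)` in place of the `to_dense()`/`scatter_max`/`reshape` lines. If the values are known to be non-negative, the `bincount` step can be dropped and `out` can simply start from zeros.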