Connecting PyTorch sparse tensors with MLIR

Greetings! First time poster here, and a relative newcomer to PyTorch (but already loving it!). I am Aart Bik, the tech lead of the MLIR Sparsifier team at Google. You can find more details about the team by following the link, but in a nutshell: we think that sparsity should be a property, i.e. a type, and not a tedious implementation concern, and that a “sparse compiler”, a.k.a. sparsifier, should deal with all the low-level details of exploiting sparsity. These ideas have been fully worked out in MLIR for a CPU pipeline, with some GPU acceleration ramping up as well.

As such, we are very interested in connecting a machine learning “front end” with our MLIR sparsifier “back end” (please forgive the very broad use of those terms), and torch.sparse seems to fit our philosophy quite well, since it keeps the operator semantics (more or less) separate from the actual sparsity of the operands.

We have started to explore whether we can connect the torch.sparse data types with the MLIR sparse tensor types (the latter form a superset of the former, since we use a TACO-flavored way of defining sparse tensor types). For this, torch-mlir seems to make the most sense for us.
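For reference, here is a quick Python sketch of the sparse layouts torch.sparse exposes today; to the best of our understanding, the MLIR sparse tensor encoding can express each of these, plus more general TACO-style level formats (the list variable is just ours, for illustration):

import torch

# Sparse layouts currently exposed by torch.sparse; the MLIR sparse tensor
# type can express all of these, and more general TACO-style formats.
torch_sparse_layouts = [
    torch.sparse_coo,  # coordinate format
    torch.sparse_csr,  # compressed sparse row
    torch.sparse_csc,  # compressed sparse column
    torch.sparse_bsr,  # block compressed sparse row
    torch.sparse_bsc,  # block compressed sparse column
]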

One of the tasks that probably requires minor changes inside the PyTorch framework itself is propagating the sparsity layout metadata a bit further than currently seems to be the case, so that we can pick up the layout while building the tensor IR in the torch dialect of torch-mlir.

In any case, I just wanted to introduce myself and put this project on your radar. Please expect a lot of questions to follow. And of course, your initial ideas are already welcome as well!

Nice “e-meeting” you all!



And to make the discussion a bit more concrete, here is my first question.
Given a very simple example:

import torch

class BikNet(torch.nn.Module):

  def __init__(self):
    super().__init__()

  def forward(self, x):
    return x.sum()

Then, building the traced graph for dense input works fine:

biknet = BikNet()
dense_input = torch.ones(64, 64)
prog = torch.export.export(biknet, args=(dense_input,))

which yields:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]"):
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sum_1'), target=None)])
Range constraints: {}

However, building the traced graph fails for sparse input:

sparse_input = dense_input.to_sparse_csr()
prog = torch.export.export(biknet, args=(sparse_input,))

This throws an exception in builder.py when trying to clone:

torch._dynamo.exc.InternalTorchDynamoError: Sparse CSR tensors do not have strides
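For what it is worth, the same message can be triggered directly by querying strides on a CSR tensor, so the failure presumably comes from tracing machinery that asks for strides, rather than from the export logic itself (a minimal repro, not the actual builder.py code path):

import torch

t = torch.ones(64, 64).to_sparse_csr()
t.stride()  # raises: Sparse CSR tensors do not have strides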

It would already go a long way if the traced graph builder actually yielded the same IR as shown above, but with the tensor input marked with the torch.sparse_csr layout. From there, we would be able to propagate the sparse type all the way down to MLIR.

Note the subtlety that, at runtime, sparse tensor input parameters of course have a 1:N relation with the actual passed-in array arguments for positions, indices, and values (and the same holds for a potential sparse output), as illustrated below. But that is something we will have to deal with later (and we have done so in the past, e.g. for Sparse JAX). For now, just getting the traced graph with sparse types would be really helpful.
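To make the 1:N relation concrete, here is a small CSR example using the standard PyTorch accessors (variable names are just for illustration):

import torch

dense = torch.tensor([[0., 2., 0.],
                      [1., 0., 3.]])
csr = dense.to_sparse_csr()

# The three constituent arrays that would actually be passed at runtime:
print(csr.crow_indices())  # positions: tensor([0, 1, 3])
print(csr.col_indices())   # indices:   tensor([1, 0, 2])
print(csr.values())        # values:    tensor([2., 1., 3.])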

Any comments on whether that is something that can be done easily? Or, if this is not the right direction, I would love to hear why not as well.

Note that I also opened feature request 117188 for this.

I hacked around a bit and was able to get a traced graph with the information I am (for now) looking for. In order to do this, I had to:

  1. add an extra layout field to torch.fx.passes.shape_prop.TensorMetadata (sketched below)
  2. add an extra layout field to torch._subclasses.fake_tensor.FakeTensor
  3. (obviously) modify printing to include those fields
  4. hack my way around to propagate the sparse layout (a very dirty solution)
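For concreteness, here is a minimal sketch of what step 1 amounts to; this is a hypothetical abbreviation (the real TensorMetadata in torch/fx/passes/shape_prop.py is a NamedTuple carrying additional fields such as memory_format and quantization info):

from typing import NamedTuple, Tuple

import torch

class TensorMetadata(NamedTuple):
    # Abbreviated for illustration; the real class has more fields.
    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool
    stride: Tuple[int, ...]
    layout: torch.layout = torch.strided  # ADDED: e.g. torch.sparse_csr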

With that, I am able to generate the following traced graph.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):    # ADDED!
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
           
Graph signature: ExportGraphSignature(
  input_specs=[
      InputSpec(
           kind=<InputKind.USER_INPUT: 1>,
           arg=TensorArgument(name='l_x_'),
           target=None,
           layout=torch.sparse_csr)       # ADDED!
  ],
  output_specs=[
     OutputSpec(
         kind=<OutputKind.USER_OUTPUT: 1>,
         arg=TensorArgument(name='sum_1'),
         target=None)
 ])
Range constraints: {}

With this information, I can prototype propagating this further in torch-mlir and then connecting this with the MLIR Sparsifier.

Stay tuned! I will report back here how that goes.

@aartbik I think you should post on the developer forum https://dev-discuss.pytorch.org/, as your topic really comes down to talking to the sparse devs and proposing changes.

discuss.pytorch.org is the user / support forum.

Apologies for starting in the wrong forum, and thanks for pointing me to the right place for this discussion. I hope to continue it with a new post on that forum.