Export module to StableHLO with communication collectives


I’m trying to export a module to StableHLO with the sharding specification and also the communication collectives inserted.
I’m using torch_xla 2.2.
This is what my module looks like:

```python
import torch
import torch.nn as nn
import torch_xla

class ScaleModule(nn.Module):
    def __init__(self, scale):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.scale = scale

    def forward(self, input: torch.Tensor):
        result = self.scale * input
        return result
```

And I’m performing the following steps:

```python
import torch
from torch.export import export
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_sharding as xs
from torch_xla.distributed.spmd import Mesh
from torch_xla.stablehlo import exported_program_to_stablehlo

input_tensor = torch.randn([8, 1, 8, 16])
# 8 devices arranged as a 2x4 mesh with axes named 'x' and 'y'
mesh = Mesh((0, 1, 2, 3, 4, 5, 6, 7), (2, 4), ('x', 'y'))
# shard dim 0 over mesh axis 'x'; other dims are unsharded (replicated across 'y')
partition_spec = ('x', None, None, None)
sharded_tensor = xs.mark_sharding(input_tensor.to(xm.xla_device()), mesh, partition_spec)
# hypothetical: the original post does not show how `model` is constructed
model = ScaleModule(torch.tensor(2.0))
exported = export(model.to(xm.xla_device()), args=(sharded_tensor.global_tensor, ))
stable = exported_program_to_stablehlo(exported)
```

I do see the sharding information on the tensor, as the mhlo.sharding attribute in the generated StableHLO. However, I don’t see the module/operations being sharded based on this information. From the user guide, it looks like the XLA SpmdPartitioner pass is responsible for this, but I’m not sure what invokes this pass.

I’d appreciate any help. Thanks!

The StableHLO you generate only captures the pre-compilation IR with sharding annotations. The SPMD partitioning pass runs at runtime, as part of XLA compilation.

Try printing/accessing the output of your exported program with this environment variable set: `XLA_FLAGS="--xla_dump_to={dir_path_to_your_hlo_dump}"`. You should then be able to see the post-optimization-pass HLO, which should contain the results of SPMD partitioning.
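A minimal sketch of that dump workflow, using only the Python standard library for the flag and file handling. It assumes `XLA_FLAGS` is set before `torch_xla` is imported (the safe choice, so the runtime sees it when it starts) and that XLA names its post-optimization text dumps `*after_optimizations.txt`, which is the usual naming under `--xla_dump_to`. The helper names are hypothetical, not part of any torch_xla API.

```python
import os
from pathlib import Path


def enable_xla_dump(dump_dir: str) -> str:
    """Append --xla_dump_to=<dump_dir> to XLA_FLAGS, keeping flags already set.

    Call this before importing torch_xla so the flag is in place when the
    XLA runtime starts up.
    """
    flag = f"--xla_dump_to={dump_dir}"
    existing = os.environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return os.environ["XLA_FLAGS"]


def find_post_optimization_dumps(dump_dir: str) -> list[str]:
    """Return the post-optimization HLO text dumps in dump_dir, sorted by name."""
    return sorted(str(p) for p in Path(dump_dir).glob("*after_optimizations.txt"))


# Usage (sketch):
#   enable_xla_dump("/tmp/xla_dump")
#   import torch_xla  # then run the program so XLA actually compiles it
#   for path in find_post_optimization_dumps("/tmp/xla_dump"):
#       print(path)
```

The key point is that nothing is dumped until XLA actually compiles the graph, which is why the output has to be materialized (printed or moved to CPU) rather than just exported.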

Hope this helps.


I see the following StableHLO.

```mlir
module @IrToHlo.6 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<f32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<8x1x8x16xf32> {mhlo.sharding = "{devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"}) -> tensor<8x1x8x16xf32> {
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<8x1x8x16xf32>
    %1 = stablehlo.multiply %arg1, %0 : tensor<8x1x8x16xf32>
    return %1 : tensor<8x1x8x16xf32>
  }
}
```
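To read the `mhlo.sharding` annotation on `%arg1`: `devices=[2,1,1,1,4]` tiles the four tensor dimensions as 2x1x1x1, and because of `last_tile_dim_replicate` the trailing 4 is a replication factor (each tile is held by 4 devices). The shapes in the StableHLO are still the global 8x1x8x16; partitioning has not happened yet. A rough stdlib sketch that computes the per-device shard shape such an annotation implies. It only handles the simple `{replicated}` and `devices=[...]` forms seen here, and `shard_shape` is a hypothetical helper, not an XLA API.

```python
import math
import re


def shard_shape(sharding: str, global_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Per-device shape implied by a simple HloSharding string.

    Handles only "{replicated}" and "{devices=[...]... [last_tile_dim_replicate]}";
    tuple, maximal, and manual shardings are out of scope for this sketch.
    """
    body = sharding.strip().strip("{}").strip()
    if body == "replicated":
        return tuple(global_shape)
    m = re.match(r"devices=\[([\d,]+)\]", body)
    if m is None:
        raise ValueError(f"unsupported sharding: {sharding!r}")
    tiles = [int(t) for t in m.group(1).split(",")]
    if "last_tile_dim_replicate" in body:
        tiles = tiles[:-1]  # trailing entry counts replicas, not a tensor dim
    if len(tiles) != len(global_shape):
        raise ValueError("tile rank does not match tensor rank")
    # Uneven dims are padded by the partitioner, so each shard holds ceil(dim / tiles).
    return tuple(math.ceil(d / t) for d, t in zip(global_shape, tiles))


arg1 = "{devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"
print(shard_shape(arg1, (8, 1, 8, 16)))  # → (4, 1, 8, 16)
```

So after SPMD partitioning, each device would be expected to work on an f32[4,1,8,16] slice of the input.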

I’m not sure I follow this sentence

You should be able to post optimization pass HLO, which should contain the results of SPMD partitinoing.

I realize that I need to run some pass to obtain the IR after SPMD partitioning, although I’m not sure how to do that.

The StableHLO is the IR before XLA's optimization passes run. PyTorch/XLA takes it, converts it to HLO, and then compiles it to an XLA executable. SPMD partitioning runs during that compilation. So I was suggesting that:

  1. your StableHLO won’t show the sharding results,
  2. you should dump XLA post-optimization HLO to see the sharding results.
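Once you have a post-optimization dump, one quick way to check whether SPMD partitioning inserted communication is to look for collective opcodes in the HLO text, and to compare the ENTRY parameter shapes against the global shape. A rough stdlib sketch, assuming the standard HLO text syntax (`name = shape opcode(operands)`) and that the ENTRY computation is printed last, as it usually is in XLA text dumps. The opcode list covers the common XLA collectives; the function names are hypothetical. Note that for a purely elementwise module like ScaleModule, sharding the input along dim 0 likely needs no communication at all, so the partitioned HLO may show only smaller per-device shapes (such as f32[4,1,8,16]) rather than any collectives.

```python
import re
from collections import Counter

# Common XLA collective opcodes. Async variants appear as "<op>-start(" ... "<op>-done(";
# only the base form and "-start" are matched so each collective is counted once.
_COLLECTIVE_RE = re.compile(
    r"\b(all-reduce|all-gather|reduce-scatter|all-to-all|collective-permute)(?:-start)?\("
)
# Matches e.g. "p0.1 = f32[4,1,8,16]{3,2,1,0} parameter(0)"
_PARAM_RE = re.compile(r"=\s*(\w+\[[\d,]*\])(?:\{[^}]*\})?\s+parameter\((\d+)\)")


def count_collectives(hlo_text: str) -> Counter:
    """Count collective instructions in HLO text, keyed by base opcode."""
    return Counter(m.group(1) for m in _COLLECTIVE_RE.finditer(hlo_text))


def entry_parameter_shapes(hlo_text: str) -> dict[int, str]:
    """Map ENTRY parameter number -> shape, e.g. {0: 'f32[4,1,8,16]'}.

    Assumes everything after the "ENTRY" keyword belongs to the entry
    computation, which holds when ENTRY is printed last.
    """
    start = hlo_text.find("ENTRY")
    entry = hlo_text[start:] if start >= 0 else hlo_text
    return {int(m.group(2)): m.group(1) for m in _PARAM_RE.finditer(entry)}


# Usage (sketch), with `path` taken from the --xla_dump_to directory:
#   with open(path) as f:
#       text = f.read()
#   print(count_collectives(text), entry_parameter_shapes(text))
```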

By compilation, I assume you mean running torch.compile on the exported program that contains the sharding annotations on the tensor?

In the post-optimization file, I see the following:

```
HloModule SyncTensorsGraph.3, entry_computation_layout={(f32[8,1,8,16]{3,2,1,0})->(f32[8,1,8,16]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true}

ENTRY SyncTensorsGraph.3 {
  p0.1 = f32[8,1,8,16]{3,2,1,0} parameter(0), sharding={replicated}
  copy = f32[8,1,8,16]{3,2,1,0} copy(p0.1)
  ROOT tuple = (f32[8,1,8,16]{3,2,1,0}) tuple(copy)
}
```

Is there a way to convert the above IR to StableHLO?
To give you some context as to why I’m asking: I’m using XLA only to export to StableHLO. I have my own compilation flow that consumes StableHLO and compiles it for a custom backend.
So I was wondering whether there is a way to have StableHLO represent a module after SPMD partitioning, which my custom compilation flow could then consume.

I appreciate the help! :slight_smile:

Hi @hsnbrg,

torch.compile is what gets you to the IR. The compilation to an XLA executable happens later via PyTorch/XLA (torch_xla). “The above IR” is the HLO. As for your additional questions:

  • yes, you can try converting HLO to MHLO and then to StableHLO. That should work.

  • the StableHLO you get via export won’t capture the post-XLA-optimization result. I think you should go through XLA compilation to get that. This may be a bit hacky for your workflow… unfortunately, we don’t have an API for that.


Just to make sure I understand correctly: torch.compile gives me HLO (hopefully respecting the sharding annotations on the exported program that I pass to it).
However, there is another compilation flow via PyTorch/XLA (torch_xla) that potentially consumes that HLO? If so, how might I use this compilation flow?

Again, thank you!

Also, FWIW, torch.compile gives me the same HLO for the exported programs of the sharded and the unsharded cases.
Is there no way to make torch.compile obey the sharding annotations even if the underlying devices don’t match them?

Let me summarize some key points…

  • torch.compile is not XLA compilation, but torch.compile with the openxla backend will eventually call XLA compilation.
  • Having said that, in my earlier comment (“The StableHLO is pre-optimization pass in XLA. PyTorch/XLA takes it and converts to HLO and then compile to an XLA executable.”), the compilation refers to XLA compilation and doesn’t involve torch.compile.
  • you won’t get the post-optimization HLO from export, as I explained in the comments above. Currently, the only way to get it is to use `--xla_dump_to`.
  • currently, no. XLA compilation is also JIT, so it happens as part of XLA execution and must conform to the current platform.
  • in the future, yes. We are working on an AOT compilation service in PyTorch/XLA.

Noted. Thank you! I’m fiddling around with hlo-opt to learn more about what you just described.