Hi,
I’m trying to export a module to StableHLO with the sharding specification applied and the communication collectives inserted. I’m using torch_xla 2.2.
This is what my module looks like:
import torch
import torch.nn as nn
import torch_xla

torch_xla.runtime.use_spmd()


class ScaleModule(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, input: torch.Tensor):
        result = self.scale * input
        return result
And I’m performing the following steps:
from torch_xla.distributed.spmd import xla_sharding as xs
from torch_xla.distributed.spmd import Mesh
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm

model = ScaleModule(2.0)  # the scale value here is arbitrary
input_tensor = torch.randn([8, 1, 8, 16])

# 2x4 logical mesh over 8 devices; shard dim 0 of the input along the 'x' axis
mesh = Mesh((0, 1, 2, 3, 4, 5, 6, 7), (2, 4), ('x', 'y'))
partition_spec = ('x', None, None, None)
sharded_tensor = xs.mark_sharding(input_tensor.to(xm.xla_device()), mesh, partition_spec)

exported = export(model.to(xm.xla_device()), args=(sharded_tensor.global_tensor,))
stable = exported_program_to_stablehlo(exported)
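This is how I’m inspecting the result (I’m assuming get_stablehlo_text on the returned StableHLOGraphModule is the right way to dump the IR):

# Dump the StableHLO text for the forward method to check
# for mhlo.sharding attributes and any collectives.
print(stable.get_stablehlo_text('forward'))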
I do see the sharding information on the tensor as part of the mhlo.sharding attribute in the generated StableHLO; however, I don’t see the module’s operations actually getting sharded based on this information. From the user guide, it looks like the XLA SpmdPartitioner pass is responsible for this, but I’m not sure what invokes that pass.
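For comparison, I’ve also been looking at the traced HLO from running the module eagerly on the XLA device, using the private torch_xla._XLAC._get_xla_tensors_hlo helper (so this may not be the intended inspection path):

# Run the sharded computation eagerly and dump the traced HLO
# before it is compiled/executed (i.e. before mark_step()).
output = model(sharded_tensor.global_tensor)
print(torch_xla._XLAC._get_xla_tensors_hlo([output]))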
I’d appreciate any help. Thanks!