Hi, I’m trying to export the following module via torch.export.export.
```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

class ShardedScaleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mesh = Mesh((0, 1, 2), (3, 1), ('x', 'y'))
        self.partition_spec = ('x', None, None, None)
        self.device = xm.xla_device()

    def forward(self, x: torch.Tensor):
        y = 2.0 * x
        xs.mark_sharding(y.to(self.device), self.mesh, self.partition_spec, use_dynamo_custom_op=True)
        z = 3.0 * y
        return z
```
However, it seems like TorchDynamo tries to trace through xs.mark_sharding, and the export consequently fails.
If this is not the right way to export a module that marks arbitrary intermediate tensors as sharded, what would be the correct approach?
Thanks!