Export module (via torch_dynamo) with arbitrary tensors marked as sharded

Hi, I’m trying to export the following module with torch.export.export:

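```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs  # module path in recent torch_xla releases
from torch_xla.distributed.spmd import Mesh


class ShardedScaleModule(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 devices arranged in a (3, 1) mesh with axes named 'x' and 'y'
        self.mesh = Mesh((0, 1, 2), (3, 1), ('x', 'y'))
        # Shard dim 0 of a 4-D tensor along mesh axis 'x'; replicate the other dims
        self.partition_spec = ('x', None, None, None)
        self.device = xm.xla_device()

    def forward(self, x: torch.Tensor):
        y = 2.0 * x
        # If y is not already on the XLA device, .to() returns a copy and only that copy is annotated
        xs.mark_sharding(y.to(self.device), self.mesh, self.partition_spec, use_dynamo_custom_op=True)
        z = 3.0 * y
        return z
```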
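For readers less familiar with the sharding arguments in the snippet: `Mesh((0, 1, 2), (3, 1), ('x', 'y'))` arranges three devices in a 3×1 grid whose axes are named 'x' and 'y', and the partition spec `('x', None, None, None)` asks for dimension 0 of a 4-D tensor to be split across mesh axis 'x' while the other three dimensions stay replicated. The plain-Python sketch below computes the resulting per-device shape. `local_shard_shape` is a hypothetical helper written for illustration, not part of torch_xla, and it assumes each spec entry is a single axis name (or None) whose mesh size divides the dimension evenly.

```python
from typing import Optional, Sequence, Tuple


def local_shard_shape(
    tensor_shape: Tuple[int, ...],
    mesh_shape: Tuple[int, ...],
    axis_names: Tuple[str, ...],
    partition_spec: Sequence[Optional[str]],
) -> Tuple[int, ...]:
    """Hypothetical helper: per-device shape of a tensor under a partition spec.

    Each partition_spec entry names the mesh axis that the corresponding tensor
    dimension is split across, or is None if that dimension is replicated.
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    # Map each mesh axis name to its size, e.g. {'x': 3, 'y': 1}
    axis_size = dict(zip(axis_names, mesh_shape))
    shape = []
    for dim, axis in zip(tensor_shape, partition_spec):
        if axis is None:
            # Replicated: every device holds the full dimension
            shape.append(dim)
            continue
        n = axis_size[axis]
        if dim % n != 0:
            # This sketch only handles evenly divisible dimensions
            raise ValueError(f"dimension {dim} is not divisible by mesh axis {axis!r} of size {n}")
        shape.append(dim // n)
    return tuple(shape)


print(local_shard_shape((6, 4, 8, 8), (3, 1), ('x', 'y'), ('x', None, None, None)))
# prints (2, 4, 8, 8)
```

With the mesh and spec from the module above, a (6, 4, 8, 8) input leaves each of the three devices holding a (2, 4, 8, 8) slice of dimension 0.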

However, it seems that torch_dynamo tries to trace through xs.mark_sharding and consequently fails.
If this is not the right way to export a module in which arbitrary tensors are marked as sharded, what is the correct way to do so?

Thanks!