Any chance to torch.export a DTensor module?

Hi. I’ve recently been using torch.export to export a DNN module expressed in DTensor.
It raises the following error:

RuntimeError: aot_export is not currently supported with traceable tensor subclass.

To reproduce the error, you may need to slightly modify detect_fake_mode in torch/_guards.py so that it can detect the FakeTensorMode held inside wrapper tensor subclasses, for example:

def detect_fake_mode(inputs: Any = None):
    ...
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            # Plain FakeTensor input: record its fake mode directly.
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
        elif is_traceable_wrapper_subclass(flat_input):
            # Wrapper subclass such as DTensor: the FakeTensors live one level down,
            # so unwrap it via __tensor_flatten__() and inspect each inner tensor.
            attrs, _ = flat_input.__tensor_flatten__()
            for attr in attrs:
                inner_tensor = getattr(flat_input, attr)
                if isinstance(inner_tensor, FakeTensor):
                    fake_modes.append((inner_tensor.fake_mode, "fake inner tensor input", i))
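For readers unfamiliar with the wrapper-subclass protocol, the traversal that the patch above adds can be modeled without PyTorch at all. This is a torch-free toy sketch, not PyTorch code: `ToyFakeMode`, `ToyFakeTensor`, `ToyWrapper`, and `detect_toy_fake_modes` are hypothetical stand-ins, and the only convention borrowed from PyTorch is that `__tensor_flatten__()` returns the names of the inner-tensor attributes plus an opaque context object.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


class ToyFakeMode:
    """Hypothetical stand-in for FakeTensorMode (compared by identity)."""


@dataclass
class ToyFakeTensor:
    """Hypothetical stand-in for FakeTensor: remembers the mode that created it."""
    fake_mode: ToyFakeMode


@dataclass
class ToyWrapper:
    """Hypothetical stand-in for a wrapper subclass such as DTensor."""
    _inner: Any
    placement: str = "replicate"

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        # Same shape as PyTorch's convention: inner attribute names + opaque context.
        return ["_inner"], self.placement


def detect_toy_fake_modes(flat_inputs: List[Any]) -> List[Tuple[ToyFakeMode, str, int]]:
    fake_modes = []
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, ToyFakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
        # hasattr() stands in for is_traceable_wrapper_subclass() in this toy.
        elif hasattr(flat_input, "__tensor_flatten__"):
            attrs, _ = flat_input.__tensor_flatten__()
            for attr in attrs:
                inner = getattr(flat_input, attr)
                if isinstance(inner, ToyFakeTensor):
                    fake_modes.append((inner.fake_mode, "fake inner tensor input", i))
    return fake_modes


mode = ToyFakeMode()
found = detect_toy_fake_modes([ToyWrapper(ToyFakeTensor(mode)), 3.0])
print(found[0][1])  # prints "fake inner tensor input"
```

Like the patch, this only looks one level deep; a subclass nested inside another subclass would need a recursive version.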

I’m posting to ask whether there is any way to torch.export an nn.Module expressed in DTensor. If so, could you share the general plan or approach? Thanks in advance for any reply.

Hi @Vremold, thanks for the question! What’s your use case for torch.export + DTensor? We don’t have plans for torch.export with DTensor yet.

As of today’s nightly, torch.compile + DTensor already works. Would that work for you?

Thanks for the reply. My reasoning is this: once a model expressed in DTensor has been trained, I think torch.export is better suited than torch.compile for deploying it. That’s why I want to torch.export a DTensor model.