Hi. I have recently been using torch.export to export a DNN module whose parameters are expressed as DTensors.
It raises the following error:

```
RuntimeError: aot_export is not currently supported with traceable tensor subclass.
```
To reproduce the error, you may need to slightly change the detect_fake_mode
API in torch/_guards.py
so that it can detect FakeTensorMode inside wrapper tensor subclasses, like:
```python
def detect_fake_mode(inputs: Any = None):
    ...
    for i, flat_input in enumerate(flat_inputs):
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
        elif is_traceable_wrapper_subclass(flat_input):
            attrs, _ = flat_input.__tensor_flatten__()
            for attr in attrs:
                inner_tensor = getattr(flat_input, attr)
                if isinstance(inner_tensor, FakeTensor):
                    fake_modes.append((inner_tensor.fake_mode, "fake inner tensor input", i))
```
I am posting to ask whether there is any way to torch.export an nn.Module expressed in DTensor. If so, could you share a general plan or approach? Thanks in advance for any reply.