Need tensor subclass support in torch.export

Hi. I’m using `torch.export` to export a model that holds a `UInt4Tensor` from torchao, and I get the following error:

aot_export is not currently supported with traceable tensor subclass.

What is the plan for supporting tensor subclasses in `torch.export`? And do you have any suggestions for my use case? Thanks!

Here are my environment and a script that reproduces the issue:
torch==2.4.0
torchao==0.1

import torch
import numpy as np
from torchao.dtypes.uint4 import UInt4Tensor


class ExampleModel(torch.nn.Module):
    """Minimal module holding a torchao ``UInt4Tensor``, used to reproduce the export error.

    The 4-bit tensor is assigned as a plain attribute (``self.x = ...``), not via
    ``register_buffer`` or as an ``nn.Parameter``.
    """

    def __init__(self):
        """Create a random (4, 8) grid of values in [0, 16) and wrap it as a UInt4Tensor."""
        super().__init__()
        # randint's upper bound is exclusive, so every value fits in 4 bits.
        unpacked = torch.randint(0, 16, (4, 8)).to(torch.uint8)
        self.x = UInt4Tensor.from_unpacked(unpacked)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Add the stored tensor, viewed as shape (8, 4) and cast to uint8, to ``input``.

        ``input`` is presumably expected to be broadcast-compatible with (8, 4);
        TODO confirm for other callers.
        """
        # view(8, 4) reshapes the (4, 8) data; it is NOT a transpose
        # (assuming UInt4Tensor.view follows torch.Tensor.view semantics).
        reshaped = self.x.view(8, 4).to(torch.uint8)
        return torch.add(input, reshaped)


def main():
    """Export ``ExampleModel`` with ``torch.export.export`` and print the exported program.

    The model is called with keyword-only example inputs (``input=<8x4 tensor>``)
    and no positional arguments.
    """
    # Sample the example input first: ExampleModel.__init__ also draws from the
    # global RNG, so this ordering determines the values of both tensors.
    example_input = torch.randn(8, 4)
    model = ExampleModel().eval()

    with torch.no_grad():
        exported_program = torch.export.export(
            model,
            args=(),
            kwargs={"input": example_input},
        )

    print("===exported_program====")
    print(exported_program)


if __name__ == "__main__":
    main()