Hi. I’m using torch.export to export a model with UInt4Tensor from torchao, and I see this error:
aot_export is not currently supported with traceable tensor subclass.
What is the plan for supporting tensor subclasses with torch.export, and is there a suggested workaround for my use case? Thanks!
Here are my environment and a repro script:
torch==2.4.0
torchao==0.1
import torch

from torchao.dtypes.uint4 import UInt4Tensor


class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Pack random 4-bit values into a UInt4Tensor subclass.
        x_uint8 = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
        self.x = UInt4Tensor.from_unpacked(x_uint8)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Reshape the packed constant and widen back to uint8 before the add.
        reshaped = self.x.view(8, 4).to(torch.uint8)
        return torch.add(input, reshaped)


def main():
    input_tensor = torch.randn(8, 4)
    model = ExampleModel()
    with torch.no_grad():
        exported_program = torch.export.export(
            model.eval(), args=(), kwargs={"input": input_tensor},
        )
    print("===exported_program====")
    print(exported_program)


if __name__ == "__main__":
    main()
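
For comparison, here is a rough sketch of the kind of workaround I'm considering: keep the packed values in a plain uint8 buffer instead of the UInt4Tensor subclass, so nothing subclass-typed crosses the export boundary. The model and buffer below are just illustrative, and I haven't verified they match the UInt4 semantics:

import torch


class PlainUInt8Model(torch.nn.Module):
    """Same computation, but the constant is a plain uint8 buffer (no subclass)."""

    def __init__(self):
        super().__init__()
        # Illustrative stand-in for the UInt4Tensor: raw 4-bit values kept in uint8.
        self.register_buffer("x", torch.randint(0, 16, (4, 8), dtype=torch.uint8))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.add(input, self.x.view(8, 4))


def main():
    input_tensor = torch.randn(8, 4)
    model = PlainUInt8Model().eval()
    with torch.no_grad():
        exported_program = torch.export.export(
            model, args=(), kwargs={"input": input_tensor},
        )
    print(exported_program)


if __name__ == "__main__":
    main()

I'd expect this variant to export since no tensor subclass is involved, but it obviously gives up the uint4 packing, so it's only a stopgap rather than a real solution.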