I’m trying to create a custom tensor class that subclasses torch.Tensor:
import torch


class CustomTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, x, *args, **kwargs):
        # Build the subclass instance from an existing tensor
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x):
        super().__init__()  # optional
        self.foo = None

    def set_foo(self, foo: str):
        self.foo = foo


def module(t):
    b = torch.add(t, torch.tensor([2]))
    b = CustomTensor(b)  # wrap the result in the subclass
    b.set_foo("hello")
    return b
I'm then attempting to use the torch.compile function on it, but I get errors like this:
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [TensorVariable()] {}
Is there a way to implement call_function UserDefinedClassVariable() [TensorVariable()] in Dynamo?
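For reference, here is a minimal sketch of a script that should exercise the same code path. Several things in it are assumptions and not part of my original setup: the helper name try_compile is hypothetical, backend="eager" is chosen only so that no C++/Triton toolchain is needed, the input is a float tensor, and the fullgraph flag is toggled to compare the two modes. Depending on the PyTorch version, Dynamo may either fall back to eager at a graph break or raise, so the script only reports which of the two happened.

```python
import torch
import torch._dynamo


class CustomTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, x, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x):
        super().__init__()  # optional
        self.foo = None

    def set_foo(self, foo: str):
        self.foo = foo


def module(t):
    b = torch.add(t, torch.tensor([2]))
    b = CustomTensor(b)
    b.set_foo("hello")
    return b


def try_compile(fullgraph: bool):
    """Hypothetical helper: compile `module` and report whether the call worked."""
    try:
        torch._dynamo.reset()  # drop cached compilations between runs
        # backend="eager" is an assumption made to avoid needing a compiler toolchain
        compiled = torch.compile(module, backend="eager", fullgraph=fullgraph)
        out = compiled(torch.tensor([1.0]))
        return "ok", type(out).__name__
    except Exception as exc:
        return "error", type(exc).__name__


# Plain eager execution works and returns the subclass
eager_out = module(torch.tensor([1.0]))
print("eager:", type(eager_out).__name__, eager_out.foo)

# With fullgraph=False Dynamo is allowed to graph-break;
# with fullgraph=True any unsupported construct is raised as an error.
results = {fullgraph: try_compile(fullgraph) for fullgraph in (False, True)}
for fullgraph, outcome in results.items():
    print(f"fullgraph={fullgraph}:", outcome)
```

If the fullgraph=False run reports "ok" while fullgraph=True reports "error", that would suggest the CustomTensor(b) call is only a graph break and not a hard failure.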