Hi everyone,
I am wondering whether there is a way to pass custom objects to forward() when using torch.jit.trace. I use custom objects to isolate some behaviour from the modules I pass them to. For example, some modules in my model can work with either Tensors or dictionaries of Tensors, as long as view() is supported, so I implemented a wrapper that provides an element-wise view() for dictionaries of Tensors.
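To make the setup concrete, a minimal sketch of such a wrapper (the class name and structure here are my own; only view() is shown) could look like:

```python
# Illustrative sketch: a wrapper that forwards view() element-wise to every
# value in a dict of tensor-like objects, so code that only needs view()
# can treat the dict like a single tensor.
class DictViewWrapper:
    def __init__(self, tensors):
        # tensors: dict mapping keys to Tensor-like objects supporting view()
        self.tensors = tensors

    def view(self, *shape):
        # Apply view() to each value and wrap the results in a new instance.
        return DictViewWrapper(
            {k: v.view(*shape) for k, v in self.tensors.items()}
        )
```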
Here are some examples that show what I want to achieve and what the problem is:
Example 1:
import torch

def foo(x, y):
    return 2 * x["a"] + y

inputs_foo = ({"a": torch.rand(3)}, torch.rand(3))
traced_foo = torch.jit.trace(foo, inputs_foo)
This works fine, as dict is supported.
Example 2:
class Bar:
    def __init__(self, a):
        self.a = a

    def return_modified_content(self):
        return self.a + 10

def bar(x: Bar, y):
    return 2 * x.return_modified_content() + y

inputs_bar = (Bar(torch.rand(3)), torch.rand(3))
traced_bar = torch.jit.trace(bar, inputs_bar)
This results in:
traced = torch._C._create_function_from_trace(
RuntimeError: Type 'Tuple[__torch__.Bar, Tensor]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced
From the error I conclude that arguments must be Tensors or default container types (lists, dicts, and tuples of Tensors).
Example 3:
class Bar2(dict):
    def __init__(self, a):
        super().__init__({"a": a})
        self.a = a

    def return_modified_content(self):
        return self.a + 10

def bar2(x: Bar2, y):
    print(type(x))
    return 2 * x.return_modified_content() + y

inputs_bar2 = (Bar2(torch.rand(3)), torch.rand(3))
traced_bar2 = torch.jit.trace(bar2, inputs_bar2)
This prints <class 'dict'> and then throws AttributeError: 'dict' object has no attribute 'return_modified_content'. It seems that the Bar2 instance is converted to a plain dict during tracing.
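The same loss of the subclass can be reproduced without torch at all: converting a dict subclass back to a plain dict (which the trace input machinery appears to do, effectively) drops the subclass type and its methods:

```python
class Bar2(dict):
    def __init__(self, a):
        super().__init__({"a": a})
        self.a = a

    def return_modified_content(self):
        return self.a + 10

b = Bar2(5)
flattened = dict(b)  # plain dict: the subclass and its methods are gone
print(type(flattened))                                # <class 'dict'>
print(hasattr(flattened, "return_modified_content"))  # False
```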
Is there a way to make this work?