torch.jit.trace: passing custom objects to forward()

Hi everyone,
I am wondering whether there is a way to pass custom objects to forward() when using torch.jit.trace. I use custom objects to isolate some behaviour from the modules I pass them to. For example, some modules in my model can work with either Tensors or dictionaries of Tensors, as long as view() is supported, so I implemented a wrapper that provides element-wise view() for dictionaries of Tensors.
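For concreteness, a wrapper of the kind described above might look like the minimal sketch below. The class name `TensorDictView` and its `__getitem__` accessor are hypothetical (this is not my actual wrapper); the only assumption is that every value in the dict is a Tensor whose view() accepts the requested shape.

```python
import torch


class TensorDictView:
    """Hypothetical wrapper that applies view() to every tensor in a dict."""

    def __init__(self, tensors):
        # Copy into a plain dict so the wrapper owns its own mapping.
        self.tensors = dict(tensors)

    def view(self, *shape):
        # Return a new wrapper whose entries are views with the requested shape.
        return TensorDictView({k: v.view(*shape) for k, v in self.tensors.items()})

    def __getitem__(self, key):
        return self.tensors[key]


d = TensorDictView({"a": torch.arange(6.0), "b": torch.zeros(6)})
v = d.view(2, 3)
```

Because `TensorDictView` is a plain Python class rather than a dict, torch.jit.trace would presumably reject it as an input in the same way as `Bar` in Example 2.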

Here are some examples that show what I want to achieve and what the problem is:

Example 1:

import torch


def foo(x, y):
    return 2 * x["a"] + y


inputs_foo = ({"a": torch.rand(3)}, torch.rand(3))
traced_foo = torch.jit.trace(foo, inputs_foo)

This works fine because dicts of Tensors are supported as trace inputs.
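A quick way to confirm that the traced function generalises to new dict inputs (same key, same tensor shape) is to compare it against the eager function. A minimal sketch, assuming Example 1 is run as-is; `x_new` and `y_new` are just illustrative names:

```python
import torch


def foo(x, y):
    return 2 * x["a"] + y


inputs_foo = ({"a": torch.rand(3)}, torch.rand(3))
traced_foo = torch.jit.trace(foo, inputs_foo)

# Fresh inputs with the same dict key and tensor shape as the example inputs.
x_new, y_new = {"a": torch.rand(3)}, torch.rand(3)
eager_out = foo(x_new, y_new)
traced_out = traced_foo(x_new, y_new)
```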

Example 2:

class Bar:
    def __init__(self, a):
        self.a = a

    def return_modified_content(self):
        return self.a + 10


def bar(x: Bar, y):
    return 2 * x.return_modified_content() + y

inputs_bar = (Bar(torch.rand(3)), torch.rand(3))
traced_bar = torch.jit.trace(bar, inputs_bar)

This results in:

traced = torch._C._create_function_from_trace(
RuntimeError: Type 'Tuple[__torch__.Bar, Tensor]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

From the error, I conclude that trace inputs must be Tensors or built-in containers (Lists, Dicts, Tuples) of Tensors, so a custom class like Bar is rejected.
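One workaround consistent with that restriction is to keep the custom object outside the trace boundary: trace a function that takes only the object's tensor fields and rebuild the object inside it. The sketch below assumes `Bar` holds a single tensor field `a`; the names `bar_flat` and `call_traced` are hypothetical. Note that tracing records only tensor operations, so any Python-level logic in `Bar` runs once at trace time and is baked into the graph.

```python
import torch


class Bar:
    def __init__(self, a):
        self.a = a

    def return_modified_content(self):
        return self.a + 10


def bar(x: Bar, y):
    return 2 * x.return_modified_content() + y


def bar_flat(a, y):
    # Only tensors cross the trace boundary; Bar is rebuilt inside, so its
    # Python methods run during tracing and their tensor ops are recorded.
    return bar(Bar(a), y)


traced_bar_flat = torch.jit.trace(bar_flat, (torch.rand(3), torch.rand(3)))


def call_traced(x: Bar, y):
    # Hypothetical convenience wrapper: unpack the custom object in eager Python.
    return traced_bar_flat(x.a, y)
```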

Example 3:

class Bar2(dict):
    def __init__(self, a):
        super().__init__({"a": a})
        self.a = a

    def return_modified_content(self):
        return self.a + 10


def bar2(x: Bar2, y):
    print(type(x))
    return 2 * x.return_modified_content() + y


inputs_bar2 = (Bar2(torch.rand(3)), torch.rand(3))

traced_bar2 = torch.jit.trace(bar2, inputs_bar2)

This prints <class 'dict'> and then raises AttributeError: 'dict' object has no attribute 'return_modified_content'. It seems that Bar2 is converted to a plain dict before it reaches the traced function.
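Since the traced function evidently receives a plain dict, a similar workaround for Bar2 is to rebuild the wrapper from the dict entries inside the traced function. This is a sketch, assuming Bar2 can always be reconstructed from its "a" entry; `bar2_traceable` is a hypothetical name, and the print from Example 3 is dropped.

```python
import torch


class Bar2(dict):
    def __init__(self, a):
        super().__init__({"a": a})
        self.a = a

    def return_modified_content(self):
        return self.a + 10


def bar2(x: Bar2, y):
    return 2 * x.return_modified_content() + y


def bar2_traceable(x, y):
    # During tracing x arrives as a plain dict, so rebuild Bar2 from its
    # "a" entry before calling the original function.
    return bar2(Bar2(x["a"]), y)


traced_bar2 = torch.jit.trace(bar2_traceable, ({"a": torch.rand(3)}, torch.rand(3)))
```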

Is there a way to make this work?

Additional question: does anyone know where to find a roadmap for the JIT?