How does torch.jit trace submodules with "unsupported" input type

Hi, I have a model with two submodules A and B.
A should not be traced (its control flow depends on the input tensor), and B takes "unsupported" input types.

From the torch.jit.script page of the PyTorch 1.10.1 documentation: "Tracing is ideal for code that operates only on Tensors and lists, dictionaries, and tuples of Tensors."
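To make the quoted rule concrete, here is a minimal sketch of an input that tracing does accept: a single argument that is a tuple of Tensors. The module `PairSum` is hypothetical and not from the original post. Putting None inside such a tuple is the kind of input that triggers the 1.10.1 error shown in the code below.

```python
import torch
import torch.nn as nn


class PairSum(nn.Module):
    """Hypothetical toy module whose single argument is a tuple of Tensors."""

    def forward(self, pair):
        a, b = pair
        return a + b


# example_inputs is a tuple of positional arguments, so the one argument
# passed here is itself a (nested) tuple of Tensors, which tracing accepts.
example = ((torch.ones(3), torch.ones(3)),)
traced_pair_sum = torch.jit.trace(PairSum(), example)

# The traced module is called with the same structure: one tuple argument.
result = traced_pair_sum((torch.ones(3), torch.full((3,), 2.0)))
```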

How can I trace only B? Alternatively, how can I trace the whole model but skip submodule A? (torch.jit.ignore does not work in this case.)

I think it should be technically possible, since torch.jit.trace correctly handles submodules with "unsupported" input types when they are traced as part of the whole model (see comment1 below).

import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.identity = nn.Identity()

    def forward(self, x1, x2):
        out = self.identity(x1)
        if x2 is not None:
            out = out + self.identity(x2)
        return out

class Bar(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.foo1 = Foo()
        self.foo2 = Foo()

    def forward(self, x):
        out = self.foo1(x, x)
        out = out + self.foo2(x, None)
        return out

bar = Bar()
data = torch.ones(42)

# comment1: This is ok
# bar = torch.jit.trace(bar, data)

# comment2: This is ok too
# bar.foo1 = torch.jit.trace(bar.foo1, (data, data))

# comment3: Fails
bar.foo2 = torch.jit.trace(bar.foo2, (data, None))

Error:
RuntimeError: Type 'Tuple[Tensor, NoneType]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

Any feedback would be greatly appreciated!

torch version: 1.10.1

Well, I found a solution that worked for me.
First apply torch.jit.script to submodule A (combined with torch.jit.ignore where needed to make it scriptable), then apply torch.jit.trace to the whole model. Scripting A first ensures that torch.jit.trace does not trace through it, so its control flow is preserved.
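A minimal sketch of that approach follows. `A`, `B` and `Model` are hypothetical stand-ins (A has a data-dependent branch, B takes an optional `None` input), not the poster's real submodules, and the torch.jit.ignore part is left out.

```python
from typing import Optional

import torch
import torch.nn as nn


class A(nn.Module):
    """Hypothetical stand-in for submodule A: the branch depends on input values."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2
        return -x


class B(nn.Module):
    """Hypothetical stand-in for submodule B: its second input may be None."""

    def forward(self, x1: torch.Tensor, x2: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = x1
        if x2 is not None:
            out = out + x2
        return out


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = A()
        self.b = B()

    def forward(self, x):
        # B receives None internally; only the top-level input must be a Tensor.
        return self.b(self.a(x), None)


model = Model()
# Step 1: script A so its data-dependent branch is compiled, not frozen by the tracer.
model.a = torch.jit.script(model.a)
# Step 2: trace the whole model; the tracer records a call into the scripted A
# and traces B inline with the concrete None it receives.
traced_model = torch.jit.trace(model, torch.ones(3))

pos = traced_model(torch.ones(3))   # A takes the "x * 2" branch
neg = traced_model(-torch.ones(3))  # A takes the "-x" branch, although tracing only saw a positive input
```

If A had been traced instead of scripted, `neg` would follow the branch recorded during tracing (`x * 2`), which is exactly the problem the scripting step avoids.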

I am still looking for a solution to the question in the title, though.
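One possible workaround for the title question is to wrap the submodule in a thin module that fixes the unsupported argument, so only Tensors cross the trace boundary. This is a sketch under the assumption that `x2` is always None for `foo2`; the wrapper name `FooWithoutX2` is hypothetical and not a PyTorch API.

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    # Same Foo as in the question, with the missing super().__init__() call added.
    def __init__(self):
        super().__init__()
        self.identity = nn.Identity()

    def forward(self, x1, x2):
        out = self.identity(x1)
        if x2 is not None:
            out = out + self.identity(x2)
        return out


class FooWithoutX2(nn.Module):
    """Hypothetical wrapper that binds x2=None so the traced signature is Tensor-only."""

    def __init__(self, foo: nn.Module):
        super().__init__()
        self.foo = foo

    def forward(self, x1):
        return self.foo(x1, None)


data = torch.ones(42)
traced_foo2 = torch.jit.trace(FooWithoutX2(Foo()), (data,))
out = traced_foo2(torch.arange(42.0))
```

The traced wrapper takes a single argument, so `Bar.forward` would have to call `self.foo2(x)` instead of `self.foo2(x, None)`. The trace also bakes in the `x2 is None` branch, so the case with a real `x2` would need a separate trace.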