Hi, I have a model with two submodules, A and B.
A should not be traced (its control flow depends on the input tensor), and B takes “unsupported” input types. Per the docs:
“Tracing is ideal for code that operates only on Tensors and lists, dictionaries, and tuples of Tensors.” (torch.jit.script, PyTorch 1.10.1 documentation)
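To make the quoted rule concrete, here is a small sketch with a hypothetical module `PairSum`. A (possibly nested) tuple of Tensors is an accepted example input, so tracing succeeds. Swapping one element for `None` is exactly the case the rule excludes.

```python
import torch
import torch.nn as nn


class PairSum(nn.Module):
    # Hypothetical module whose only input is a tuple of two Tensors.
    def forward(self, pair):
        a, b = pair
        return a + b


# example_inputs is a 1-tuple whose single element is itself a tuple of Tensors.
traced = torch.jit.trace(PairSum(), ((torch.ones(3), torch.ones(3)),))
out = traced((torch.ones(3), torch.ones(3)))
print(out)  # tensor([2., 2., 2.])
```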
How can I trace only B? (Or, alternatively, trace the whole model but skip submodule A? torch.jit.ignore does not work in this case.)
I think this is technically possible, since torch.jit.trace correctly traces submodules with “unsupported” input types when they are called from a traced parent (see comment1 below).
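For the general A/B layout, one commonly used pattern is to trace only B and then script the parent. `torch.jit.script` keeps an already traced submodule as-is and compiles the rest from source, so A's data-dependent `if` is preserved. This is a sketch, assuming B's own inputs are plain Tensors. `A`, `B` and `Model` are hypothetical stand-ins, not the modules from this post.

```python
import torch
import torch.nn as nn


class A(nn.Module):
    # Hypothetical stand-in for A: data-dependent control flow, so it is scripted, not traced.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2
        return x - 1


class B(nn.Module):
    # Hypothetical stand-in for B: straight-line tensor ops, safe to trace.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = A()
        self.b = B()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.b(self.a(x))


model = Model()
# Trace only B, then script the whole model; A's `if` is compiled, not frozen.
model.b = torch.jit.trace(model.b, torch.randn(2, 4))
scripted = torch.jit.script(model)
```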
```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.identity = nn.Identity()

    def forward(self, x1, x2):
        out = self.identity(x1)
        if x2 is not None:
            out = out + self.identity(x2)
        return out


class Bar(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo1 = Foo()
        self.foo2 = Foo()

    def forward(self, x):
        out = self.foo1(x, x)
        out = out + self.foo2(x, None)
        return out


bar = Bar()
data = torch.ones(42)

# comment1: This is ok
# bar = torch.jit.trace(bar, data)

# comment2: This is ok too
# bar.foo1 = torch.jit.trace(bar.foo1, (data, data))

# comment3: Fails
bar.foo2 = torch.jit.trace(bar.foo2, (data, None))
```
```
RuntimeError: Type 'Tuple[Tensor, NoneType]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced
```
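Comment1 works because, inside the parent's `forward`, `None` is just a Python constant; only the top-level `example_inputs` have to be Tensors. A possible workaround, assuming you can change the call site in `Bar.forward`, is to trace a small adapter that takes only `x` and passes `None` itself. `DropNone` is a hypothetical name, and this is one sketch rather than an official API. Note that the `None` branch is frozen into the trace.

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    # Same module as in the question.
    def __init__(self):
        super().__init__()
        self.identity = nn.Identity()

    def forward(self, x1, x2):
        out = self.identity(x1)
        if x2 is not None:
            out = out + self.identity(x2)
        return out


class DropNone(nn.Module):
    # Hypothetical adapter: Tensor-only signature; `None` is a constant inside
    # forward, so the tracer never sees it as an input.
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x, None)


class Bar(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo1 = Foo()
        self.foo2 = DropNone(Foo())  # call site below now passes only `x`

    def forward(self, x):
        out = self.foo1(x, x)
        out = out + self.foo2(x)
        return out


bar = Bar()
data = torch.ones(42)
bar.foo2 = torch.jit.trace(bar.foo2, data)  # Tensor-only input, so this traces
print(bar(data)[:3])  # tensor([3., 3., 3.])
```

If changing the call site is not an option, an alternative is to script `Foo` instead of tracing it, after annotating `x2` as `Optional[torch.Tensor]`. TorchScript refines the type on the `is not None` check, so `bar.foo2` can keep being called as `(x, None)`.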
Any feedback would be greatly appreciated!
torch version: 1.10.1