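Control flow that depends on `torch.rand` or Python's `random.random` is traced without any error, but the resulting graph is unexpected: which branch gets recorded depends on the random value drawn at trace time, so repeated traces of the same module produce different graphs (see the examples below). I understand that the docs point out this nondeterministic behavior and mention that `torch.rand` is not traceable. Still, is it reasonable for FX to silently trace through such code and bake in one nondeterministic outcome?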
```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        rand = torch.rand(1).item()
        if rand < 0.5:
            x = torch.relu(x)
        return x


mod = MyModule()
traced_mod = torch.fx.symbolic_trace(mod)
traced_mod.graph.print_tabular()
"""
opcode       name    target    args       kwargs
-----------  ------  --------  ---------  --------
placeholder  x       x         ()         {}
call_module  linear  linear    (x,)       {}
output       output  output    (linear,)  {}

or

opcode         name    target                                                   args       kwargs
-------------  ------  -------------------------------------------------------  ---------  --------
placeholder    x       x                                                        ()         {}
call_module    linear  linear                                                   (x,)       {}
call_function  relu    <built-in method relu of type object at 0x7f14be3091a0>  (linear,)  {}
output         output  output                                                   (relu,)    {}
"""
```
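A minimal sketch of what appears to be going on, based on how FX symbolic tracing works in general: `symbolic_trace` only records operations that receive a `Proxy`. `torch.rand(1)` has no `Proxy` argument, so it runs eagerly during tracing, `.item()` returns an ordinary Python float, and the `if` is resolved once, at trace time (the same applies to `random.random()`). The snippet retraces the module under different seeds and checks that both graphs show up across traces, while any single traced module always takes the same branch. The `has_relu` helper is hypothetical, written only for this check.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # torch.rand(1) has no Proxy argument, so during tracing it runs eagerly
        # and .item() returns a plain Python float; the `if` is decided right here.
        if torch.rand(1).item() < 0.5:
            x = torch.relu(x)
        return x


def has_relu(gm: torch.fx.GraphModule) -> bool:
    # Hypothetical helper: True if the traced graph contains a relu call_function node.
    return any(n.op == "call_function" and "relu" in str(n.target) for n in gm.graph.nodes)


mod = MyModule()
outcomes = set()
for seed in range(20):
    torch.manual_seed(seed)  # the RNG state at trace time picks the branch that gets baked in
    outcomes.add(has_relu(torch.fx.symbolic_trace(mod)))
print(outcomes)

# A single traced module has no randomness left in it, so repeated calls agree.
gm = torch.fx.symbolic_trace(mod)
x = torch.randn(4, 512)
with torch.no_grad():
    first = gm(x)
    same = all(torch.equal(first, gm(x)) for _ in range(10))
print(same)
```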
Wrapping `torch.rand` in a function registered with `torch.fx.wrap` shows the same behavior:
```python
import torch
import torch.fx


@torch.fx.wrap
def torch_rand(shape):
    return torch.rand(shape)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        rand = torch_rand(1).item()
        if rand < 0.5:
            x = torch.relu(x)
        return x


mod = MyModule()
traced_mod = torch.fx.symbolic_trace(mod)
traced_mod.graph.print_tabular()
"""
opcode         name    target                                                   args       kwargs
-------------  ------  -------------------------------------------------------  ---------  --------
placeholder    x       x                                                        ()         {}
call_module    linear  linear                                                   (x,)       {}
call_function  relu    <built-in method relu of type object at 0x7f5ef0c251a0>  (linear,)  {}
output         output  output                                                   (relu,)    {}

or

opcode       name    target    args       kwargs
-----------  ------  --------  ---------  --------
placeholder  x       x         ()         {}
call_module  linear  linear    (x,)       {}
output       output  output    (linear,)  {}
"""
```
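The `torch.fx.wrap` version likely behaves the same for a related reason: the FX tracer only records a wrapped function as a `call_function` node when at least one of its arguments is a `Proxy`; otherwise it simply calls the function. `torch_rand(1)` receives only the constant `1`, so it still runs at trace time. (If it were recorded, using the resulting `Proxy` in an `if` would raise a trace error rather than be deferred.) Replacing the `if` with `torch.where` would not help either, since `torch.rand(1)` would still be evaluated once and stored in the graph as a constant tensor. One possible workaround, sketched below under the assumption that a fresh coin flip per forward call is what is wanted, is to wrap the whole random branch in a function that takes `x`. `random_relu` is a hypothetical name, not part of PyTorch.

```python
import torch
import torch.fx


@torch.fx.wrap
def random_relu(x):
    # Hypothetical helper. Registered with torch.fx.wrap and called with a Proxy (x),
    # so FX records one call_function node instead of tracing into this body.
    # The coin flip therefore runs on every GraphModule call, not once at trace time.
    if torch.rand(1).item() < 0.5:
        return torch.relu(x)
    return x


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        return random_relu(self.linear(x))


gm = torch.fx.symbolic_trace(MyModule())
# List (op, target) pairs; the wrapped function appears as a single call_function node.
print([(n.op, getattr(n.target, "__name__", n.target)) for n in gm.graph.nodes])

torch.manual_seed(0)
x = torch.randn(4, 512)
with torch.no_grad():
    outputs = [gm(x) for _ in range(50)]
# relu outputs contain no negatives, raw linear outputs almost surely do,
# so seeing both True and False means both branches ran at call time.
branches = {bool((out < 0).any()) for out in outputs}
print(branches)
```

The trade-off is that the branch becomes opaque to FX: graph transformations see one `random_relu` node rather than the `relu` inside it.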
PyTorch version:

```
$ python -c "import torch; print(torch.__version__)"
1.10.2+cu113
```