When inspecting which PyTorch operators get called by using `TorchDispatchMode`, I sometimes see `aten::detach` (depending on the model), which is not listed here: IRs — PyTorch 2.7 documentation
Reproducer:
```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


class MyDispatch(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}  # kwargs may be None; avoid unpacking `**None`
        print(func.name())  # fully qualified op name, e.g. "aten::relu"
        return func(*args, **kwargs)


model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
x = torch.rand(1, 1, 32, 32)

with MyDispatch():
    model(x)
```
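To narrow down where `aten::detach` comes from, one option is to record the op names in a list instead of printing them, and to run the same forward pass once with autograd enabled and once with it disabled via `torch.set_grad_enabled(False)`. This is a minimal sketch assuming a recent PyTorch 2.x; `RecordOps` and `record` are hypothetical helper names, not PyTorch APIs, and `TorchDispatchMode` lives in the private `torch.utils._python_dispatch` module, so its behavior may change between releases.

```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


class RecordOps(TorchDispatchMode):
    """Hypothetical helper: collects the name of every aten op seen by the mode."""

    def __init__(self):
        super().__init__()
        self.names = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.names.append(func.name())
        return func(*args, **kwargs)


def record(model, x, grad_enabled=True):
    # Hypothetical helper: one forward pass under the mode, autograd on or off.
    with torch.set_grad_enabled(grad_enabled), RecordOps() as mode:
        model(x)
    return mode.names


model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
x = torch.rand(1, 1, 32, 32)

with_grad = record(model, x, grad_enabled=True)
no_grad = record(model, x, grad_enabled=False)
print("grad enabled: ", with_grad)
print("grad disabled:", no_grad)
```

If `aten::detach` shows up only in the grad-enabled list, that would suggest it comes from autograd bookkeeping during the forward pass rather than from the model's own code.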
My questions:
- What is `aten::detach`?
- Why does this code call `aten::detach`?
- Why does the code not call `aten::detach` for other models, for example `model = nn.Sequential(nn.ReLU(), nn.Conv2d(1, 1, 1))`, where the order of `ReLU` and `Conv2d` is reversed compared to the reproducer model, or when building my own `ReLU` with `torch.maximum`?
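The three cases from the last question can be compared side by side with a small sketch like the following. It defines a hypothetical `RecordOps` mode that collects op names, and it assumes the hand-written ReLU is `torch.maximum(x, torch.zeros_like(x))`; the exact module from the question is not shown, so `MaxReLU` is a hypothetical stand-in.

```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode


class RecordOps(TorchDispatchMode):
    """Hypothetical helper: collects the name of every aten op seen by the mode."""

    def __init__(self):
        super().__init__()
        self.names = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.names.append(func.name())
        return func(*args, **kwargs)


class MaxReLU(nn.Module):
    """Hypothetical hand-written ReLU built on torch.maximum."""

    def forward(self, x):
        return torch.maximum(x, torch.zeros_like(x))


variants = {
    "Conv2d -> ReLU": nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()),
    "ReLU -> Conv2d": nn.Sequential(nn.ReLU(), nn.Conv2d(1, 1, 1)),
    "Conv2d -> MaxReLU": nn.Sequential(nn.Conv2d(1, 1, 1), MaxReLU()),
}

x = torch.rand(1, 1, 32, 32)
results = {}
for label, model in variants.items():
    with RecordOps() as mode:
        model(x)
    results[label] = mode.names
    # Report only whether aten::detach was dispatched for this variant.
    print(f"{label}: detach seen = {'aten::detach' in mode.names}")
```

Printing one boolean per variant keeps the comparison short; the full lists in `results` show which other ops each model dispatches.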