Hello,
I’m importing TorchDispatchMode like so: from torch.utils._python_dispatch import TorchDispatchMode
I’m getting an error when I create an instance of my TorchDispatchMode subclass (the marked line in the code below). It would be great if someone could help me figure out the right way to construct the mode object for my version of PyTorch.
I’m using PyTorch v1.12.0a0+02fb0b0f.nv22.06 on a Jetson Orin device.
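Code:

from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode

class FlopCounterMode(TorchDispatchMode):
    def __init__(self, module=None):
        # flop_counts[module_name][op_name] -> number of FLOPs
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        # stack of the modules currently executing
        self.parents = ['Global']
        if module is not None:
            # hook every direct child; enter_module/exit_module and
            # __torch_dispatch__ are defined elsewhere in the class (not shown)
            for name, child in module.named_children():
                child.register_forward_pre_hook(self.enter_module(name))
                child.register_forward_hook(self.exit_module(name))

flop_counter = FlopCounterMode(mod)  # this line is causing the error
with flop_counter:
    optimizer.zero_grad()
    outputs = mod(inp)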
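The snippet calls self.enter_module(name) and self.exit_module(name) without showing them. For reference, here is a minimal, hypothetical sketch of what such hook factories usually do, assuming they only maintain the self.parents stack so that counted FLOPs can be attributed to the modules currently running. A forward pre-hook is called as hook(module, inputs) and a forward hook as hook(module, inputs, output). The class name ParentTracker and the record method are made up for illustration, and the sketch is plain Python, so it does not depend on the PyTorch version and does not by itself address the constructor error.

```python
from collections import defaultdict


class ParentTracker:
    """Hypothetical stand-in for the hook bookkeeping in FlopCounterMode."""

    def __init__(self):
        # flop_counts[module_name][op_name] -> accumulated FLOPs
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        # Stack of modules currently executing; 'Global' always stays at the bottom
        self.parents = ['Global']

    def enter_module(self, name):
        # Factory for a forward pre-hook, called as hook(module, inputs)
        def pre_hook(module, inputs):
            self.parents.append(name)
            # Returning None leaves the inputs unchanged
        return pre_hook

    def exit_module(self, name):
        # Factory for a forward hook, called as hook(module, inputs, output)
        def post_hook(module, inputs, output):
            if self.parents[-1] != name:
                raise RuntimeError(
                    f"unbalanced hooks: expected {name!r}, got {self.parents[-1]!r}"
                )
            self.parents.pop()
            # Returning None leaves the output unchanged
        return post_hook

    def record(self, op_name, flops):
        # Hypothetical helper: credit the FLOPs to every module on the stack
        for parent in self.parents:
            self.flop_counts[parent][op_name] += flops


# Simulate one child module running by calling the hooks by hand
tracker = ParentTracker()
tracker.enter_module("layer1")(None, ())
tracker.record("mm", 100)
tracker.exit_module("layer1")(None, (), None)
tracker.record("add", 5)
print(dict(tracker.flop_counts["Global"]))  # {'mm': 100, 'add': 5}
print(dict(tracker.flop_counts["layer1"]))  # {'mm': 100}
```

Because every module on the stack is credited, an op that runs inside layer1 shows up under both 'layer1' and 'Global'.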