Error while creating an object of TorchDispatchMode

Hello,

I’m importing `TorchDispatchMode` like this: `from torch.utils._python_dispatch import TorchDispatchMode`.

I’m getting the error below when I create an instance of my `TorchDispatchMode` subclass. It would be great if someone could help me figure out the right way to create the object for my version of PyTorch.

I’m using PyTorch v1.12.0a0+02fb0b0f.nv22.06 on a Jetson Orin device.

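Code:

```python
from collections import defaultdict

from torch.utils._python_dispatch import TorchDispatchMode


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, module=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']
        if module is not None:
            # Register scope-tracking hooks on each direct child of `module`
            for name, child in dict(module.named_children()).items():
                child.register_forward_pre_hook(self.enter_module(name))
                child.register_forward_hook(self.exit_module(name))

    # enter_module, exit_module and __torch_dispatch__ are not shown here


flop_counter = FlopCounterMode(mod)  # this line is causing the error
with flop_counter:
    optimizer.zero_grad()
    outputs = mod(inp)
```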
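The snippet calls `self.enter_module(name)` and `self.exit_module(name)` but doesn't show them. Below is a minimal, torch-free sketch of what such hook factories might look like, assuming they maintain the `parents` stack so that counts get attributed to every enclosing module. The class `ModuleScopeTracker` and the `record` helper are hypothetical, and the hooks are called by hand instead of being registered on real modules; their call shapes follow forward pre-hooks `(module, inputs)` and forward hooks `(module, inputs, output)`.

```python
from collections import defaultdict


class ModuleScopeTracker:
    """Hypothetical stand-in for the scope-tracking half of FlopCounterMode.

    Keeps a stack of the submodules currently running so that each recorded
    count is attributed to every enclosing scope, including 'Global'.
    """

    def __init__(self):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']

    def enter_module(self, name):
        # Same call shape as a forward pre-hook: hook(module, inputs)
        def hook(module, inputs):
            self.parents.append(name)
        return hook

    def exit_module(self, name):
        # Same call shape as a forward hook: hook(module, inputs, output)
        def hook(module, inputs, output):
            if self.parents[-1] != name:
                raise RuntimeError(f"unbalanced module scopes: {self.parents}")
            self.parents.pop()
        return hook

    def record(self, op_name, count):
        # Hypothetical helper: add `count` to every scope currently on the stack
        for scope in self.parents:
            self.flop_counts[scope][op_name] += count


# Usage: call the hooks by hand to simulate one child module running
tracker = ModuleScopeTracker()
pre_hook = tracker.enter_module('layer1')
post_hook = tracker.exit_module('layer1')

pre_hook(None, ())
tracker.record('aten.mm', 100)   # counted under 'Global' and 'layer1'
post_hook(None, (), None)
tracker.record('aten.add', 5)    # counted under 'Global' only
```

In an actual dispatch mode, something like `record` would presumably be called from `__torch_dispatch__` with the intercepted op and its computed FLOPs; that part is left out here because the mode API differs between PyTorch versions.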