Timing model modules

I am trying to time each layer of a network by registering forward hooks. I then run four inference benchmarks, each timed roughly as in the sketch after this list:

  • A control, running the model as is
  • Compiling for the CPU with torch.compile()
  • Running the model on the GPU
  • Compiling the model to Triton (GPU) with torch.compile()
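
The harness itself is straightforward (a simplified sketch; `bench`, `model`, and `x` are illustrative names, not my exact code):

    import time
    import torch

    def bench(model, x, n_warmup=10, n_iters=100):
        # Warm-up absorbs torch.compile's one-time compilation cost so
        # it is not counted in the steady-state numbers.
        with torch.no_grad():
            for _ in range(n_warmup):
                model(x)
            if x.is_cuda:
                torch.cuda.synchronize()  # drain queued GPU work before timing
            start = time.perf_counter()
            for _ in range(n_iters):
                model(x)
            if x.is_cuda:
                torch.cuda.synchronize()  # make sure all kernels have finished
        return (time.perf_counter() - start) / n_iters

    # e.g. bench(model, x) vs. bench(torch.compile(model), x)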

However, the introduction of hooks seems to cause the compiler a lot of problems. I get several cache-size-limit hits, meaning Dynamo recompiled the same functions over and over until it gave up and fell back to eager execution. As a result, the torch.compile() benches come out slower than their uncompiled counterparts. Yet running these same benches without hooks, using the torch profiler instead, I get the expected result: the compiled version is usually faster.
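
To dig into the recompiles, I have been experimenting along these lines (a sketch; the limit value is arbitrary):

    import torch
    import torch._dynamo

    # The guard failures that force each recompile can be logged by running
    # the script with TORCH_LOGS="recompiles" set in the environment.
    # Raising the per-function recompile limit (the default is small) shows
    # whether hitting the cap is what triggers the fallback to eager:
    torch._dynamo.config.cache_size_limit = 64

    compiled_model = torch.compile(model)
    compiled_model(x)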

My hooks are set up like this:

    import time  # needed for the wall-clock timestamps below

    def register_hooks(self, model):
        def pre_hook_fn(module, input):
            module.start_time = time.time()

        def forward_hook_fn(module, input, output):
            end_time = time.time()
            rt = end_time - module.start_time
            # append each (layer, elapsed-time) pair to a list
            self.layer_times.append((str(module), float(rt)))

        for module in model.modules():
            module.register_forward_pre_hook(pre_hook_fn)
            module.register_forward_hook(forward_hook_fn)

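For the GPU benches there is also a separate caveat: CUDA launches kernels asynchronously, so time.time() inside a hook measures launch overhead rather than actual execution time. A variant based on CUDA events would look roughly like this (a sketch; `timed_modules` is an illustrative attribute that would need initialising alongside `layer_times`):

    import torch

    def register_cuda_hooks(self, model):
        # CUDA events are timestamped on the GPU stream itself, so no
        # host-side synchronisation is needed inside every hook.
        def pre_hook_fn(module, input):
            module._start_evt = torch.cuda.Event(enable_timing=True)
            module._end_evt = torch.cuda.Event(enable_timing=True)
            module._start_evt.record()

        def forward_hook_fn(module, input, output):
            module._end_evt.record()
            self.timed_modules.append(module)  # read the events after sync

        for module in model.modules():
            module.register_forward_pre_hook(pre_hook_fn)
            module.register_forward_hook(forward_hook_fn)

    # After the forward pass:
    #     torch.cuda.synchronize()
    #     for m in self.timed_modules:
    #         self.layer_times.append(
    #             (str(m), m._start_evt.elapsed_time(m._end_evt)))  # milliseconds
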
I am confused as to why these hooks cause the compiler so much trouble. The torch profiler is sparsely documented and cumbersome to use, and it does not give me the per-layer granularity I would like; my usage so far is roughly this (a sketch, with `model` and `x` as above):
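
    from torch.profiler import profile, ProfilerActivity

    # Aggregates per ATen op rather than per nn.Module, hence the
    # granularity gap described above:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        model(x)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))

Any advice would be appreciated, thanks!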