Autograd FLOP Calculation with Higher Order Derivatives

I’m working on FLOP (floating-point operation) counting and ran into a challenge with higher-order derivatives in torch.autograd. Specifically, I want to compute FLOPs for code that calls autograd.grad(..., create_graph=True) so that the resulting gradients can themselves be differentiated.
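
For concreteness, this is a minimal sketch of the pattern I mean (the toy model and names are purely illustrative):

```python
import torch

# Toy setup: a tiny linear model and a scalar loss.
x = torch.randn(8, 4)
w = torch.randn(4, 1, requires_grad=True)
loss = (x @ w).pow(2).mean()

# First-order gradient; create_graph=True keeps the graph so the
# gradient itself can be differentiated again.
(grad_w,) = torch.autograd.grad(loss, w, create_graph=True)

# A second-order quantity built from that gradient (here, the gradient
# of the gradient norm) — this is the kind of double-backward call
# whose FLOPs I'd like to count.
grad_norm = grad_w.pow(2).sum()
(second_order,) = torch.autograd.grad(grad_norm, w)
```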

I’ve noticed that none of the existing FLOP profiling libraries I’ve tested (torch.profiler, fvcore, ptflops) supports FLOP counting when higher-order derivatives are involved. They either don’t account for autograd operations at all (and just estimate the backward cost from the forward pass) or fail outright when create_graph=True is enabled.
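
For reference, this is roughly the kind of profiling attempt I mean (using torch.profiler with with_flops=True; the wrapped computation is just an illustrative double backward):

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(8, 4)
w = torch.randn(4, 1, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU], with_flops=True) as prof:
    loss = (x @ w).pow(2).mean()
    # Higher-order case: differentiate a quantity built from the gradient.
    (grad_w,) = torch.autograd.grad(loss, w, create_graph=True)
    penalty = grad_w.pow(2).sum()
    penalty.backward()

# The reported FLOPs only cover ops the profiler has formulas for, and
# the double-backward work does not show up in a usable way for me.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```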

In theory, the computational cost of second-order information could scale as O(n^2), where n is the number of parameters (the size of the full Hessian), but in practice I’ve observed that it often runs much faster (see the sketch below). I’m sure there are optimizations PyTorch makes under the hood to avoid a naive calculation of the Hessian. Has anyone else explored FLOP profiling for such use cases, or can anyone share insights into how autograd handles these operations under the hood to stay well below the theoretical worst case? Or, if necessary, is it possible to add support for these types of FLOP calculations?
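To make concrete what “much faster than O(n^2)” looks like in my case: a Hessian-vector product via double backward (a common create_graph=True pattern, and the kind of thing I want to count FLOPs for) only ever touches gradient-sized tensors, never an n × n matrix. A minimal sketch with a toy function, purely illustrative:

```python
import torch

n = 10_000
w = torch.randn(n, requires_grad=True)
v = torch.randn(n)

loss = (w.sin() * w).sum()
(g,) = torch.autograd.grad(loss, w, create_graph=True)

# Hessian-vector product H @ v: differentiate g . v with respect to w.
(hvp,) = torch.autograd.grad(g @ v, w)
print(hvp.shape)  # torch.Size([10000]) — no (n, n) Hessian is ever built
```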

Thanks