At the beginning of my model training, I run a single forward and backward pass under the FlopCounterMode context manager to compute the total number of FLOPs. I recently upgraded from torch==2.3 to torch==2.7, and I'm now getting the following error:
NotImplementedError: There was no rule registered for HOP triton_kernel_wrapper_mutation and mode <torch.utils.flop_counter._FlopCounterMode object at 0x7f6798e07cb0>. We recommend filing an issue.
Error: SystemExit: worker with local_rank 0 exited with non-zero exitcode 1
I've determined this is because I have a custom Triton kernel in my code. Ideally, I'd like to just dispatch to the usual implementation and not bother with estimating the FLOPs of this kernel (its contribution is negligible anyway). I tried to use the HigherOrderOperator.py_impl method, but I can't manage to make it work correctly. Specifically, I'm not sure how exactly to register my implementation for FlopCounterMode (which should just run my Triton kernel as normal).
I'm running on Python 3.12, without torch.compile, on an H200 GPU.