Using flop_counter with torch.compile

Can __torch_dispatch__ be overridden for compiled modules? I want to get FLOP counts for an Inductor-compiled model and tried something like this:

import torch
import torchvision.models as models
from torch.utils.flop_counter import FlopCounterMode

inp = torch.randn(1, 3, 224, 224, device='cuda')
mod = torch.compile(models.resnet18().to('cuda'))

flop_counter = FlopCounterMode(mod, depth=4)
with flop_counter:
    mod(inp).sum().backward()
But this raises all sorts of errors. Is there any workaround?