Is __torch_dispatch__ overridable for compiled methods? I wanted to get the FLOP counts for an Inductor-compiled model and was trying something like this:
import torch
import torchvision.models as models
from torch.utils.flop_counter import FlopCounterMode

inp = torch.randn(1, 3, 224, 224, device='cuda')
mod = torch.compile(models.resnet18().to('cuda'))
flop_counter = FlopCounterMode(mod, depth=4)
with flop_counter:
    res = mod(inp).sum().backward()
But this causes all sorts of errors. Is there a workaround?