I’m doing some graph editing with torch.fx and the timm library.
I already edited resnet18 this way and it worked without issues.
But with MobileNet I run into a problem: the traced graph module is missing some keys that the original model has.
As a result, the original model trains fine (as expected), but the traced graph module does not.
Example:
from torch import fx
import timm
model = timm.create_model("mobilenetv2_035", num_classes=10)
fx_module = fx.symbolic_trace(model)
print(len(model.state_dict().keys()))
print(len(fx_module.state_dict().keys()))
print(set(model.state_dict().keys()).difference(fx_module.state_dict().keys()))
Output:
314
262
{'blocks.1.0.bn3.num_batches_tracked', 'blocks.3.0.bn3.num_batches_tracked', 'blocks.2.2.bn2.num_batches_tracked', 'blocks.5.1.bn3.num_batches_tracked', 'blocks.3.2.bn1.num_batches_tracked', 'blocks.4.2.bn3.num_batches_tracked', 'blocks.2.0.bn3.num_batches_tracked', 'blocks.5.1.bn2.num_batches_tracked', 'blocks.4.2.bn2.num_batches_tracked', 'blocks.4.1.bn1.num_batches_tracked', 'blocks.3.1.bn3.num_batches_tracked', 'blocks.1.1.bn3.num_batches_tracked', 'bn1.num_batches_tracked', 'blocks.4.2.bn1.num_batches_tracked', 'bn2.num_batches_tracked', 'blocks.4.0.bn1.num_batches_tracked', 'blocks.4.0.bn3.num_batches_tracked', 'blocks.5.1.bn1.num_batches_tracked', 'blocks.4.1.bn2.num_batches_tracked', 'blocks.4.0.bn2.num_batches_tracked', 'blocks.5.0.bn1.num_batches_tracked', 'blocks.5.2.bn1.num_batches_tracked', 'blocks.3.3.bn2.num_batches_tracked', 'blocks.5.2.bn3.num_batches_tracked', 'blocks.2.2.bn1.num_batches_tracked', 'blocks.5.0.bn3.num_batches_tracked', 'blocks.3.1.bn1.num_batches_tracked', 'blocks.2.1.bn1.num_batches_tracked', 'blocks.4.1.bn3.num_batches_tracked', 'blocks.2.2.bn3.num_batches_tracked', 'blocks.0.0.bn2.num_batches_tracked', 'blocks.3.3.bn1.num_batches_tracked', 'blocks.3.2.bn3.num_batches_tracked', 'blocks.0.0.bn1.num_batches_tracked', 'blocks.1.1.bn1.num_batches_tracked', 'blocks.3.2.bn2.num_batches_tracked', 'blocks.3.1.bn2.num_batches_tracked', 'blocks.2.0.bn1.num_batches_tracked', 'blocks.3.0.bn2.num_batches_tracked', 'blocks.2.1.bn3.num_batches_tracked', 'blocks.1.1.bn2.num_batches_tracked', 'blocks.3.0.bn1.num_batches_tracked', 'blocks.6.0.bn3.num_batches_tracked', 'blocks.5.0.bn2.num_batches_tracked', 'blocks.2.0.bn2.num_batches_tracked', 'blocks.6.0.bn1.num_batches_tracked', 'blocks.2.1.bn2.num_batches_tracked', 'blocks.6.0.bn2.num_batches_tracked', 'blocks.1.0.bn2.num_batches_tracked', 'blocks.3.3.bn3.num_batches_tracked', 'blocks.5.2.bn2.num_batches_tracked', 'blocks.1.0.bn1.num_batches_tracked'}
How can I preserve the "num_batches_tracked" buffers in the traced graph module?
Maybe this could be achieved with torch._dynamo or something similar?
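In the meantime, here is the workaround I’m experimenting with: re-registering every dropped buffer from the original model onto the corresponding submodule of the traced GraphModule. My guess is that fx only copies attributes that are actually read during tracing, so a bookkeeping buffer like num_batches_tracked that the traced forward never touches gets dropped. The sketch below uses a minimal stand-in module (BNLike and Net are my own illustrations, not timm code) so it runs without timm, but the same loop works on the mobilenetv2_035 example above:

```python
import torch
from torch import fx, nn


class BNLike(nn.Module):
    """Stand-in for a norm layer whose bookkeeping buffer is never
    read in forward(), so symbolic_trace drops it from the result."""

    def __init__(self):
        super().__init__()
        self.register_buffer("num_batches_tracked", torch.tensor(0))
        self.weight = nn.Parameter(torch.ones(4))

    def forward(self, x):
        # Only self.weight is accessed here, so only it survives the trace.
        return x * self.weight


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = BNLike()

    def forward(self, x):
        return self.bn(x)


model = Net()
fx_module = fx.symbolic_trace(model)

# The unused buffer is missing from the traced module's state_dict:
missing = set(model.state_dict()) - set(fx_module.state_dict())

# Workaround: re-register each dropped buffer on the matching submodule
# of the GraphModule, sharing storage with the original model's tensor.
traced_buffers = dict(fx_module.named_buffers())
for name, buf in model.named_buffers():
    if name not in traced_buffers:
        parent_path, _, buf_name = name.rpartition(".")
        fx_module.get_submodule(parent_path).register_buffer(buf_name, buf)
```

This restores the state_dict keys without touching the graph itself, and since the buffers are shared tensors, updates on either module are visible in the other. I’d still prefer a way to make the tracer keep them in the first place, though.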