Torch.fx.symbolic_trace removes some of the keys from the module’s state_dict

I’m doing some graph editing using torch.fx and the timm library.
I already edited resnet18 and it worked completely successfully.
But when using MobileNet I ran into a problem: the obtained graph is missing some keys that the original model had.
As a result, the original model fits well (as expected), but the obtained graph does not.
Example

from torch import fx
import timm

model = timm.create_model("mobilenetv2_035", num_classes=10)
fx_module = fx.symbolic_trace(model)  # trace the model into an fx.GraphModule

# Compare state_dict keys between the original model and the traced module
print(len(model.state_dict().keys()))
print(len(fx_module.state_dict().keys()))
print(set(model.state_dict().keys()).difference(fx_module.state_dict().keys()))
Output:

314
262
{'blocks.1.0.bn3.num_batches_tracked', 'blocks.3.0.bn3.num_batches_tracked', 'blocks.2.2.bn2.num_batches_tracked', 'blocks.5.1.bn3.num_batches_tracked', 'blocks.3.2.bn1.num_batches_tracked', 'blocks.4.2.bn3.num_batches_tracked', 'blocks.2.0.bn3.num_batches_tracked', 'blocks.5.1.bn2.num_batches_tracked', 'blocks.4.2.bn2.num_batches_tracked', 'blocks.4.1.bn1.num_batches_tracked', 'blocks.3.1.bn3.num_batches_tracked', 'blocks.1.1.bn3.num_batches_tracked', 'bn1.num_batches_tracked', 'blocks.4.2.bn1.num_batches_tracked', 'bn2.num_batches_tracked', 'blocks.4.0.bn1.num_batches_tracked', 'blocks.4.0.bn3.num_batches_tracked', 'blocks.5.1.bn1.num_batches_tracked', 'blocks.4.1.bn2.num_batches_tracked', 'blocks.4.0.bn2.num_batches_tracked', 'blocks.5.0.bn1.num_batches_tracked', 'blocks.5.2.bn1.num_batches_tracked', 'blocks.3.3.bn2.num_batches_tracked', 'blocks.5.2.bn3.num_batches_tracked', 'blocks.2.2.bn1.num_batches_tracked', 'blocks.5.0.bn3.num_batches_tracked', 'blocks.3.1.bn1.num_batches_tracked', 'blocks.2.1.bn1.num_batches_tracked', 'blocks.4.1.bn3.num_batches_tracked', 'blocks.2.2.bn3.num_batches_tracked', 'blocks.0.0.bn2.num_batches_tracked', 'blocks.3.3.bn1.num_batches_tracked', 'blocks.3.2.bn3.num_batches_tracked', 'blocks.0.0.bn1.num_batches_tracked', 'blocks.1.1.bn1.num_batches_tracked', 'blocks.3.2.bn2.num_batches_tracked', 'blocks.3.1.bn2.num_batches_tracked', 'blocks.2.0.bn1.num_batches_tracked', 'blocks.3.0.bn2.num_batches_tracked', 'blocks.2.1.bn3.num_batches_tracked', 'blocks.1.1.bn2.num_batches_tracked', 'blocks.3.0.bn1.num_batches_tracked', 'blocks.6.0.bn3.num_batches_tracked', 'blocks.5.0.bn2.num_batches_tracked', 'blocks.2.0.bn2.num_batches_tracked', 'blocks.6.0.bn1.num_batches_tracked', 'blocks.2.1.bn2.num_batches_tracked', 'blocks.6.0.bn2.num_batches_tracked', 'blocks.1.0.bn2.num_batches_tracked', 'blocks.3.3.bn3.num_batches_tracked', 'blocks.5.2.bn2.num_batches_tracked', 'blocks.1.0.bn1.num_batches_tracked'}

How could I preserve the “num_batches_tracked” buffers in the obtained graph module?
Maybe it could be achieved using torch._dynamo or something similar?

I believe your problem is that num_batches_tracked is not a learnable parameter; it’s likely an internal counter used to update batch norm statistics, whereas a symbolic trace only preserves the learnable parameters. We can confirm this by running this script:

from torch import fx
import timm

model = timm.create_model("mobilenetv2_035", num_classes=10)
fx_module = fx.symbolic_trace(model)

# Count parameters with requires_grad=True (each True sums as 1)
num_learnable_params_model = sum(p.requires_grad for p in model.parameters())
num_learnable_params_fx_module = sum(p.requires_grad for p in fx_module.parameters())

print(num_learnable_params_model)      # 158
print(num_learnable_params_fx_module)  # 158
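
So the learnable parameters survive the trace; the difference should lie entirely in the registered buffers. Continuing the script above, a quick sketch to confirm — the missing buffer names should be exactly the num_batches_tracked entries:

model_buffers = set(name for name, _ in model.named_buffers())
fx_buffers = set(name for name, _ in fx_module.named_buffers())
# Expect this to print only the num_batches_tracked names
print(model_buffers - fx_buffers)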

Unsure how relevant this is to the original issue, but note that num_batches_tracked is an internal buffer of the nn.BatchNormXd layers:

import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.num_batches_tracked)
# tensor(0)

which is used to compute the exponential_average_factor when the cumulative moving average is used (i.e. when momentum=None).
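
For example, a minimal sketch: with momentum=None, batch norm switches to a cumulative moving average, where the effective update factor is 1 / num_batches_tracked, and the counter increments on every training-mode forward pass:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=None)  # momentum=None -> cumulative moving average
bn.train()
for _ in range(5):
    bn(torch.randn(8, 3, 4, 4))
print(bn.num_batches_tracked)
# tensor(5)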

Is there a way, then, to keep the non-learnable parameters in the graph if they’re used for forward computation in the original model?
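
One possible workaround (a sketch, assuming the traced module keeps the same submodule paths, which appears to hold here since the other batch norm buffers survive): re-register the dropped buffers from the original model onto the GraphModule. This restores state_dict compatibility, but note the traced forward still won’t update these counters, since the generated code never references them:

missing = set(model.state_dict().keys()) - set(fx_module.state_dict().keys())
for name in missing:
    parent, _, attr = name.rpartition(".")  # e.g. "bn1", ".", "num_batches_tracked"
    src = model.get_submodule(parent)       # get_submodule("") returns the root module
    dst = fx_module.get_submodule(parent)
    dst.register_buffer(attr, getattr(src, attr).clone())

# The key sets should now match
assert set(model.state_dict().keys()) == set(fx_module.state_dict().keys())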