aot_export_module captures more forward variables than model parameters

Hey folks,

I want to capture the forward and backward computation of the Swin Transformer with aot_export_module.
The code below does that, but I ran into the following issue: the model has 173 parameters, while the decomposed forward function takes 186 variables, i.e. 184 after subtracting the 2 user inputs (1 for inputs and 1 for target).
I reproduced the discrepancy with other vision transformers as well, and it grows with the depths parameter of the Swin Transformer. Could you help me?

import torchvision
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition
import torch.fx as fx
import torch
import torch.nn as nn

num_classes = 1000

user_defined_model = torchvision.models.swin_transformer.SwinTransformer(
    patch_size=[4, 4],
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 8, 12],
    window_size=[7, 7],
    num_classes=num_classes,
).cuda()

class WrappedModel(nn.Module):
    def __init__(self, nn_model, num_classes):
        super().__init__()
        self.model = nn_model
        self.num_classes = num_classes

    def forward(self, x, target):
        output = self.model(x)
        # CrossEntropyLoss accepts class probabilities, so a one-hot float target works
        target_onehot_float = torch.nn.functional.one_hot(target, self.num_classes).float()
        loss = nn.CrossEntropyLoss()(output.float(), target_onehot_float)
        # With trace_joint=True only the loss may require grad, so detach the other output
        return (loss, output.detach())

counter = 0
for name, para in user_defined_model.named_parameters():
    if para.requires_grad:
        counter += 1
print(f"Number model parameters {counter}") #173
# Create tensors directly on the GPU so that inputs stays a leaf tensor
inputs = torch.randn(size=(1, 3, 224, 224), device="cuda", requires_grad=True)
target = torch.randint(high=num_classes, size=(1,), dtype=torch.long, device="cuda")

model = WrappedModel(user_defined_model,num_classes)
m, _ = aot_export_module(
    model, [inputs, target], trace_joint=True, output_loss_index=0, decompositions=None
)
fwd, bwd = default_partition(m, [inputs, target], num_fwd_outputs=1)

Environment (Conda env):
- nvcc: Build cuda_11.8.r11.8/compiler.31833905_0
- Python 3.8

I also tried with the latest PyTorch (2.5).

There are also some buffers in the model. If I add this code to your script:

buffer_counter = 0
for name, buf in user_defined_model.named_buffers():
    buffer_counter += 1
print(f"Number model buffers {buffer_counter}")

I get:

Number model parameters 173
Number model buffers 12
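The 12 buffers line up with the number of blocks implied by depths=[2, 2, 6, 2] (2 + 2 + 6 + 2 = 12), which would be consistent with each block registering one buffer (presumably a relative position index) and with the discrepancy growing with depths. The sketch below uses a hypothetical ToyWindowBlock, not torchvision's actual Swin block, to show how such per-block buffers scale with depth while never being counted by parameters().

```python
import torch
import torch.nn as nn


class ToyWindowBlock(nn.Module):
    """Hypothetical stand-in for one Swin block: trainable parameters plus
    one integer index buffer created with register_buffer."""

    def __init__(self, dim=8, window=2):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)  # 2 parameters: weight and bias
        n = window * window
        # A buffer is saved in state_dict but is not returned by parameters()
        self.register_buffer("relative_position_index", torch.arange(n * n).view(n, n))

    def forward(self, x):
        return self.qkv(x)


def count_tensors(model):
    """Return (number of trainable parameters, number of buffers)."""
    n_params = sum(1 for p in model.parameters() if p.requires_grad)
    n_buffers = sum(1 for _ in model.buffers())
    return n_params, n_buffers


results = {}
for depths in ([2, 2, 6, 2], [2, 2, 2, 2]):
    # One toy block per unit of depth, as in the depths list of the Swin config
    model = nn.Sequential(*[ToyWindowBlock() for _ in range(sum(depths))])
    results[tuple(depths)] = count_tensors(model)
    print(depths, results[tuple(depths)])
```

Because aot_export_module turns buffers into graph inputs just like parameters, counting only named_parameters() undercounts the graph's inputs by the number of buffers.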


Yes, you are right: they probably come from PyTorch's register_buffer(). I found a way to compile the model properly, but I had to remove some essential parts of the architecture, so I think there is more work to do.