Hey folks,
I want to capture the forward and backward computation of the Swin Transformer.
The code below tries to do that, but I ran into the following issue: the model has 173 parameters, whereas the number of variables unpacked in the generated forward function is 186 - 2 = 184 (subtracting 1 for `target` and 1 for `inputs`).
The same discrepancy shows up with other vision transformers as well, and it scales with the `depths` parameter of the Swin Transformer. Could you help me understand where the extra variables come from?
import torchvision
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition
import torch.fx as fx
import torch
import torch.nn as nn
# Number of classes; passed to the model below and reused by the loss wrapper.
num_classes = 1000

# Swin Transformer built from explicit hyper-parameters, then moved to the GPU.
user_defined_model = torchvision.models.swin_transformer.SwinTransformer(
    patch_size=[4, 4],
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 8, 12],
    window_size=[7, 7],
    num_classes=num_classes,
).cuda()
class WrappedModel(nn.Module):
    """Wrap a classifier so that ``forward`` also computes the training loss.

    ``forward`` returns ``(loss, output.detach())``: the loss comes first and
    the model output is detached, so the loss is the only returned tensor that
    carries gradient information.

    Args:
        nn_model: Module that maps an input batch to per-class scores.
        num_classes: Number of classes used to one-hot encode the targets.
    """

    def __init__(self, nn_model: nn.Module, num_classes: int):
        super().__init__()
        self.model = nn_model
        self.num_classes = num_classes

    def forward(self, x, target):
        """Run the wrapped model on ``x`` and score it against ``target``.

        Args:
            x: Input batch forwarded unchanged to the wrapped model.
            target: Integer class indices (``torch.long``), as required by
                ``torch.nn.functional.one_hot``.

        Returns:
            Tuple ``(loss, output)`` holding the cross-entropy loss and the
            detached model output.
        """
        output = self.model(x)
        # The targets are deliberately fed as float one-hot vectors (class
        # probabilities) rather than as indices, matching the original code.
        target_onehot_float = torch.nn.functional.one_hot(
            target, self.num_classes
        ).float()
        # Functional form of nn.CrossEntropyLoss() with default settings; it
        # avoids constructing a new loss module on every forward call.
        loss = torch.nn.functional.cross_entropy(
            output.float(), target_onehot_float
        )
        return loss, output.detach()
# Count the trainable parameters of the bare (unwrapped) model.
counter = sum(
    1 for _, para in user_defined_model.named_parameters() if para.requires_grad
)
print(f"Number model parameters {counter}") #173

# Registered buffers are not parameters, but aot_export_module lifts them
# into graph inputs alongside the parameters. Printing their count helps
# explain why the exported forward takes more inputs than there are
# parameters (presumably the Swin `relative_position_index` buffers, one per
# block -- TODO confirm against the generated forward).
num_buffers = sum(1 for _ in user_defined_model.buffers())
print(f"Number model buffers {num_buffers}")

# Create the example inputs directly on the GPU. Calling `.cuda()` on a CPU
# tensor created with `requires_grad=True` returns a non-leaf copy instead.
inputs = torch.randn(size=(1, 3, 224, 224), device="cuda", requires_grad=True)
target = torch.randint(
    high=num_classes, size=(1,), dtype=torch.long, device="cuda"
)

model = WrappedModel(user_defined_model, num_classes)

# Trace the joint forward/backward graph; output 0 of `forward` is the loss.
m, _ = aot_export_module(
    model,
    [inputs, target],
    trace_joint=True,
    output_loss_index=0,
    decompositions=None,
)

# NOTE(review): `WrappedModel.forward` returns two values (loss, output) but
# `num_fwd_outputs` is 1 -- confirm this split is intended.
fwd, bwd = default_partition(m, [inputs, target], num_fwd_outputs=1)

# Dump both graphs as importable Python modules.
fwd.to_folder(folder="forward_vit", module_name="Forward")
bwd.to_folder(folder="backward_vit", module_name="Backward")
Environment (conda):
- pytorch-2.4.0-py3.8_cuda11.8_cudnn9.1.0_0
- nvcc: Build cuda_11.8.r11.8/compiler.31833905_0
- python 3.8
I also tried with the latest PyTorch 2.5.