I want to create a learnable parameter when aggregating several models together.

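In your current approach, the use of `weight` is not captured by a differentiable operation, because you are manipulating the `state_dict` directly and loading it afterwards. Loading copies the values in without recording them in the autograd graph, so no gradient can flow back to `weight`.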
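Below is a minimal sketch of one way to keep the aggregation differentiable. It assumes that all models share one architecture and that you are on PyTorch 2.x, which provides `torch.func.functional_call`. Instead of writing the blended values into a `state_dict`, the weighted sum is computed inside `forward`, and the result is passed to a template module. `make_model`, `WeightedAggregate`, and the softmax over per-model logits are hypothetical, illustrative choices and are not taken from the original code.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def make_model() -> nn.Module:
    # Hypothetical architecture; every aggregated model must share it.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


class WeightedAggregate(nn.Module):
    """Blends the parameters of same-architecture models with learnable weights."""

    def __init__(self, models):
        super().__init__()
        # The first model only defines the forward pass; its own values are not trained.
        self.template = models[0]
        for p in self.template.parameters():
            p.requires_grad_(False)
        self.names = [n for n, _ in self.template.named_parameters()]
        # Frozen snapshots of each model's parameters (kept on their current device).
        self.sources = [
            {n: p.detach().clone() for n, p in m.named_parameters()} for m in models
        ]
        # One learnable logit per model; softmax keeps the weights positive and summing to 1.
        self.logits = nn.Parameter(torch.zeros(len(models)))

    def forward(self, x):
        w = torch.softmax(self.logits, dim=0)
        # The weighted sum is an ordinary tensor op, so autograd records it.
        merged = {
            n: sum(w[i] * src[n] for i, src in enumerate(self.sources))
            for n in self.names
        }
        # Run the template with the merged tensors in place of its own parameters.
        return functional_call(self.template, merged, (x,))


if __name__ == "__main__":
    torch.manual_seed(0)
    agg = WeightedAggregate([make_model() for _ in range(3)])
    opt = torch.optim.SGD([agg.logits], lr=0.1)  # only the blend weights are optimized
    loss = agg(torch.randn(16, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    print(torch.softmax(agg.logits, dim=0))
```

The main design choice is `functional_call`. It leaves the original modules untouched while still routing gradients to `logits`, which is the piece that the `state_dict` round trip loses.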
You could check this post to see if a parametrization would work.
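The sketch below shows one way a parametrization could express the blend. It is an illustrative assumption and not necessarily what the linked post describes. It relies on `torch.nn.utils.parametrize.register_parametrization` to re-express each parameter of `model_a` as `a * original + (1 - a) * other`, where `other` is the matching parameter from `model_b` and `a = sigmoid(logit)` is learnable. `BlendWith` and `blend_models` are hypothetical names, and the two models are assumed to have identical architectures.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class BlendWith(nn.Module):
    """Hypothetical parametrization: weight = a * original + (1 - a) * other."""

    def __init__(self, other: torch.Tensor):
        super().__init__()
        self.register_buffer("other", other.detach().clone())
        self.logit = nn.Parameter(torch.zeros(()))  # a = sigmoid(0) = 0.5 at start

    def forward(self, original):
        a = torch.sigmoid(self.logit)
        return a * original + (1 - a) * self.other


def blend_models(model_a: nn.Module, model_b: nn.Module) -> nn.Module:
    """Parametrizes every parameter of model_a as a blend with model_b's counterpart."""
    params_b = dict(model_b.named_parameters())
    # Snapshot the module list: registering parametrizations adds new submodules.
    for mod_name, module in list(model_a.named_modules()):
        for p_name, p in list(module.named_parameters(recurse=False)):
            p.requires_grad_(False)  # freeze the original values; only logits train
            full_name = f"{mod_name}.{p_name}" if mod_name else p_name
            parametrize.register_parametrization(
                module, p_name, BlendWith(params_b[full_name])
            )
    return model_a
```

After blending, passing `[p for p in model.parameters() if p.requires_grad]` to an optimizer trains only the blend logits, because the original weights and the buffers from `model_b` stay frozen.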