Add custom weights for instance normalisation

I’m reimplementing a paper that uses adaptive instance norm (AdaIN), but with custom weights coming from another network.

I know I could code the norm layer from scratch (it’s not long or hard), but I was looking for a cleaner solution.

I didn’t find an AdaIN layer in PyTorch, so:

Can I use InstanceNorm2d from torch.nn and just assign my weights as attributes of the module (after casting them to nn.Parameter)? Will the gradients still flow back correctly through them?
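Concretely, here is a minimal sketch of what I have in mind (the `style_net`, its sizes, and all shapes are placeholders I made up):

```python
import torch
import torch.nn as nn

# hypothetical conditioning network that predicts per-channel scale and shift
style_net = nn.Linear(16, 2 * 64)

norm = nn.InstanceNorm2d(64, affine=True)   # owns .weight and .bias, shape (64,)

style = torch.randn(1, 16)
gamma, beta = style_net(style).squeeze(0).chunk(2)

# replace the layer's affine parameters with the predicted tensors;
# does wrapping them in nn.Parameter keep the graph back to style_net?
norm.weight = nn.Parameter(gamma)
norm.bias = nn.Parameter(beta)

x = torch.randn(1, 64, 32, 32)
y = norm(x)
```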

Or is there some other way to pass custom weights to a norm layer, e.g. a functional approach like the sketch below?
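For example, would something like this be idiomatic? This is just a sketch assuming one scale and one shift per channel and per sample; the `AdaIN2d` name and the conditioning shapes are my own invention.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN2d(nn.Module):
    """Instance-normalise x, then apply externally predicted scale and shift."""
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, gamma, beta):
        # x: (N, C, H, W); gamma, beta: (N, C) from the conditioning network
        x = F.instance_norm(x, eps=self.eps)  # plain normalisation, no affine params
        return gamma[:, :, None, None] * x + beta[:, :, None, None]

# usage: the conditioning network would predict gamma/beta per sample
adain = AdaIN2d()
x = torch.randn(4, 64, 32, 32)
gamma = torch.randn(4, 64, requires_grad=True)
beta = torch.randn(4, 64, requires_grad=True)
y = adain(x, gamma, beta)
y.sum().backward()                        # gradients reach gamma and beta
print(gamma.grad.shape, beta.grad.shape)  # torch.Size([4, 64]) twice
```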

Thanks a lot in advance.