How to dynamically modify parameters (specifically weights) during forward pass?

Context:
Progressive Growing Of GANs For Improved Quality, Stability, And Variation
(https://arxiv.org/pdf/1710.10196.pdf)

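Specifically, I am trying to implement the equalized learning rate from that paper, where each layer rescales its weights at runtime by the per-layer normalization constant from the He initializer.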


I’ve found two (vaguely) similar questions here:

  1. Changing weight after forward and before backward

  2. Adjusting parameters in the forward pass.

Currently, this is my implementation (which I suspect to be wrong):

import torch
import torch.nn as nn

class WeightScaledModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.weight = self.module.weight

    def forward(self, x):
        # fan_in of the wrapped module's weight
        fan = torch.nn.init._calculate_correct_fan(self.module.weight, 'fan_in')
        # Scale by sqrt(2 / fan_in) and write the result back as a new Parameter
        self.module.weight = nn.Parameter(self.module.weight * (2/fan)**.5)
        return self.module(x)

Any help on how to “dynamically” adjust parameters would be appreciated!

I implemented it as https://gist.github.com/SsnL/44b082499381478150abfeabaf2701d2, and it seems to work.
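
For reference, here is a minimal sketch of one way to do this. It is not necessarily what the gist does. The idea is to leave the stored parameter alone and build the scaled weight as a temporary tensor inside forward. Wrapping the result in a new nn.Parameter on every call, as in the question, applies the scale again on each forward pass and creates a fresh leaf tensor that an existing optimizer does not track. The class name WeightScaledLinear and the N(0, 1) init are my own choices for illustration, and only the linear case is shown.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightScaledLinear(nn.Module):
    """Linear layer whose weight is rescaled on every forward pass (hypothetical name)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Raw parameters; the optimizer updates these tensors directly.
        # Drawing the weight from N(0, 1) is an assumption of this sketch.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        # He constant sqrt(2 / fan_in); fan_in of a linear layer is in_features.
        self.scale = math.sqrt(2.0 / in_features)

    def forward(self, x):
        # The scaled weight is a temporary tensor. self.weight is never
        # reassigned, so autograd sends gradients back to the raw parameter.
        return F.linear(x, self.weight * self.scale, self.bias)


layer = WeightScaledLinear(8, 4)
x = torch.randn(2, 8)
before = layer.weight.detach().clone()

out = layer(x)
out.sum().backward()
```

After the forward and backward pass, layer.weight still holds the unscaled values and layer.weight.grad is populated, so any optimizer built from layer.parameters() keeps working. The same pattern should carry over to conv layers by calling F.conv2d with the scaled weight.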