Module weights as function of other parameters

Given a base model, I want to implement another model that recomputes the weights on every forward pass (W = W0 * func(W1)) and uses W as the module weights for its computations, but trains only the parameters W0 and W1 (instead of training W directly).

I have been trying the following, but this doesn’t seem to work:

from copy import deepcopy

import torch
import torch.nn as nn

class AdvModel(nn.Module):
    def __init__(self, model):
        super(AdvModel, self).__init__()
        self.model = model
        for mod in select_params(model):  # yields the selected submodules of model
            # (W0) this will be updated by autograd
            mod.register_parameter(name='W0', param=deepcopy(mod.weight))
            # (W1) Two cases:
            #  - this is updated by autograd; or
            #  - this is updated manually after each iteration
            # (the init value here is just a placeholder)
            mod.register_parameter(name='W1', param=deepcopy(mod.weight))
            # (W) the original weight need not be updated
            # (frozen after copying, so W0/W1 keep requires_grad=True)
            mod.weight.requires_grad = False

    def forward(self, x):
        for mod in select_params(self.model):
            adv_W1 = func_1(mod.W1)
            # func_1 may be a differentiable or a non-differentiable function
            mod.weight = torch.nn.Parameter(mod.W0 * adv_W1)
        return self.model(x)
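For context, here is a minimal standalone sketch of what I suspect is happening (the layer, shapes, and the sigmoid stand-in for func_1 are placeholders, not my actual model): re-wrapping the product in nn.Parameter creates a fresh leaf, so gradients stop at weight; assigning the product as a plain tensor instead keeps W0 and W1 in the graph:

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 3, bias=False)
W0 = nn.Parameter(lin.weight.detach().clone())
W1 = nn.Parameter(torch.zeros_like(W0))

# Assign the recomputed weight as a *plain* tensor that carries
# grad history back to W0 and W1. The registered Parameter must be
# removed first, since nn.Module refuses to assign a non-Parameter
# tensor over a registered parameter name.
del lin.weight
lin.weight = W0 * torch.sigmoid(W1)

out = lin(torch.randn(2, 4))
out.sum().backward()
print(W0.grad is not None, W1.grad is not None)  # True True
```

(If relevant: torch.nn.utils.parametrize.register_parametrization seems designed for exactly this weight-as-a-function-of-other-parameters pattern.)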

I’m a bit confused here: what exactly is the error, and is W1 trained in addition to W0? It seems strange to say W1 is being trained if it doesn’t appear to be a leaf of the computation graph that actually receives gradients. Additionally, how would gradients for W1 be computed if func_1, as described, isn’t differentiable?
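To check my understanding of the leaf-node question, a tiny standalone example (names are mine, just for illustration):

```python
import torch

W1 = torch.nn.Parameter(torch.randn(3))  # leaf: created directly by the user
y = torch.sigmoid(W1)                    # non-leaf: produced by an op, has a grad_fn
print(W1.is_leaf, y.is_leaf)  # True False
```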

The model is trained with the regular cross entropy loss.
When func_1 is non-differentiable, W1 is not trainable (W1 would instead be set manually before training/inference).
When func_1 is differentiable, W1 will be trainable (since W1 would then be a leaf node of the graph).
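A small standalone sketch of the two cases; the concrete func_1 stand-ins are placeholders I picked (sigmoid for the differentiable case, a hard threshold for the non-differentiable one):

```python
import torch

def grads_flow(func_1):
    # Build W = W0 * func_1(W1) and check which leaves get gradients.
    W0 = torch.nn.Parameter(torch.ones(3))
    W1 = torch.nn.Parameter(torch.randn(3))
    W = W0 * func_1(W1)
    W.sum().backward()
    return (W0.grad is not None, W1.grad is not None)

print(grads_flow(torch.sigmoid))              # (True, True): W1 gets gradients
print(grads_flow(lambda t: (t > 0).float()))  # (True, False): the threshold
                                              # breaks the graph, W1.grad stays None
```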