How to "unregister" a Parameter from a module

Is it possible to unregister a Parameter from an instance of an nn.Module? Let’s say I want to go through all Conv2d layers of a network and replace all weight parameters with my own custom nn.Module. I can’t simply reassign the weight attribute to my own module, as I get:

TypeError: cannot assign 'CustomWeight' as parameter 'weight' (torch.nn.Parameter or None expected)

(I don’t want to define a custom Conv2d layer because I don’t want users of my application to change code on their side)
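
Roughly what I’m trying, as a minimal sketch (the real CustomWeight is shown further down in this thread):

import torch
import torch.nn as nn

class CustomWeight(nn.Module):  # placeholder just to reproduce the error; real version below
    def __init__(self, weight):
        super().__init__()

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())

for module in model.modules():
    if isinstance(module, nn.Conv2d):
        # raises: TypeError: cannot assign 'CustomWeight' as parameter 'weight' ...
        module.weight = CustomWeight(module.weight)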

Based on the error message, it looks like you are trying to assign plain tensors as the parameters. Could you try to wrap them in nn.Parameter and wrap the assignment in a with torch.no_grad(): guard?
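
Something like this rough sketch, assuming new_weight is the plain tensor you want to use:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3)
new_weight = torch.randn_like(conv.weight)  # stand-in for your custom weight tensor

with torch.no_grad():
    # wrapping in nn.Parameter keeps 'weight' a properly registered parameter
    conv.weight = nn.Parameter(new_weight)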

That doesn’t seem to work:

TypeError: _make_subclass(): argument 'data' (position 2) must be Tensor, not CustomWeight

I think the error is fair. I’m basically trying to replace a Parameter with an nn.Module.

Yeah, that won’t work.
Could you share the code for CustomWeight?
If that’s not possible, could you explain how you wanted to assign this module to the conv parameters?
I assume CustomWeight is not just storing the parameters but also applying some operations. Is that correct?

So imagine at some point in the training I want to multiply my weights by a mask before they’re used in the forward method (I want to freeze the weights and learn the mask instead).

I can of course subclass all the layers I’m interested in, add this mask, and override the forward method, but I was wondering whether I could dynamically remove the weight parameter from the modules and monkey-patch them with something like the code below:

import torch
import torch.nn as nn

class CustomWeight(nn.Module):
    def __init__(self, weight):
        super().__init__()
        # learnable mask with the same shape as the original weight
        self.weight_mask = nn.Parameter(torch.ones_like(weight))
        # frozen copy of the original weight (a plain tensor, not a registered Parameter)
        self.weight = weight.detach().clone()
        self.weight.requires_grad = False

    def forward(self, x):
        return self.weight * self.weight_mask

Does this make sense?

Can you remove that from the parameters passed to the optimizer? That way, this specific parameter will not be updated anymore.
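
For example, a rough sketch assuming model is your network and you want to freeze all conv weights:

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))  # example network

# collect the conv weights to freeze and pass everything else to the optimizer
frozen_ids = {id(m.weight) for m in model.modules() if isinstance(m, nn.Conv2d)}
trainable = [p for p in model.parameters() if id(p) not in frozen_ids]

optimizer = optim.SGD(trainable, lr=0.1)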

It does make sense, but what you want is currently not supported in PyTorch.
There is issue #7313 that discusses this.
I once implemented a little technical exploration (linked from there).
There is still some latent interest in doing this eventually, but sadly it doesn’t seem to be a priority.

Apropos spectral_norm in the title of that issue: there you have a similar goal, and you can see that the current approach is, let’s say, somewhat elaborate.
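
For illustration only, here is a rough sketch of that hook-based style of workaround applied to your masking use case (just to show the technique, not an official API):

import torch
import torch.nn as nn

def apply_weight_mask(conv):
    weight = conv.weight
    del conv._parameters['weight']  # unregister the original Parameter
    conv.register_buffer('weight_orig', weight.detach())  # frozen copy of the weights
    conv.register_parameter('weight_mask', nn.Parameter(torch.ones_like(weight)))

    def hook(module, inputs):
        # recompute the effective weight before every forward call
        module.weight = module.weight_orig * module.weight_mask

    conv.register_forward_pre_hook(hook)
    hook(conv, None)  # set module.weight once so the module is usable right away
    return conv

conv = apply_weight_mask(nn.Conv2d(3, 16, 3))
out = conv(torch.randn(1, 3, 8, 8))  # gradients now flow into weight_mask only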

Best regards

Thomas

Is there some reason this doesn’t solve the problem (at least on a case-by-case basis)?
