Change nn.Parameter type to nn.Module type

I want to replace a parameter inside a model with a custom nn.Module layer that I created. Below is a simplified example of what I am trying to do.

Code:

import torch
import torch.nn as nn

class change_to_layer(nn.Module):
  def __init__(self):
    super().__init__()
    self.w = nn.Parameter(torch.randn(100, 100))
  
  def __mul__(self, other):
    return self.forward(other)
  
  def __rmul__(self, other):
    return self.forward(other)

  def forward(self, x):
    return x @ self.w


class simple_model(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(100, 100)
    self.scale = nn.Parameter(torch.ones(1))
    self.fc2 = nn.Linear(100, 100)

  def forward(self, x):
    x = self.fc1(x)
    x = self.scale * x
    x = self.fc2(x)
    return x  # return the output so that print(model(input)) shows it instead of None


model = simple_model()

model.scale = change_to_layer()  # change nn.Parameter to nn.Module (error occurs)

input = torch.randn(100)
print(model(input))

However, I encounter the following error.

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-15-1a4295f999b1> in <cell line: 33>()
     31 model = simple_model()
     32 
---> 33 model.scale = change_to_layer()
     34 
     35 input = torch.randn(100)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
   1633         elif params is not None and name in params:
   1634             if value is not None:
-> 1635                 raise TypeError("cannot assign '{}' as parameter '{}' "
   1636                                 "(torch.nn.Parameter or None expected)"
   1637                                 .format(torch.typename(value), name))

TypeError: cannot assign '__main__.change_to_layer' as parameter 'scale' (torch.nn.Parameter or None expected)
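To see where this check comes from, you can inspect how scale is stored on the model. This is a minimal sketch; it uses nn.Identity() as a stand-in for change_to_layer (an assumption made only to keep it short) and reads the internal _parameters dict, which is a private attribute of nn.Module.

```python
import torch
import torch.nn as nn


class simple_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 100)
        self.scale = nn.Parameter(torch.ones(1))
        self.fc2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.fc2(self.scale * self.fc1(x))


model = simple_model()

# nn.Module.__setattr__ stores parameters in the internal _parameters dict,
# not in the instance __dict__, so 'scale' is tracked as a registered parameter.
print("scale" in model._parameters)  # True
print("scale" in model.__dict__)     # False

# Assigning a non-Parameter to a name already in _parameters raises TypeError.
raised = False
try:
    model.scale = nn.Identity()  # stand-in for change_to_layer()
except TypeError as e:
    raised = True
    print(e)
```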

How can I change the type of this attribute?
Or, more directly, how can I replace an nn.Parameter with an nn.Module?

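You could delete the parameter first via del model.scale before assigning the new module to this attribute. nn.Module.__setattr__ refuses to assign anything other than an nn.Parameter (or None) to a name that is already registered as a parameter; del removes that registration, so the following assignment registers the new object as a submodule instead.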
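Applied to the example above, a minimal sketch could look like this. replace_parameter is a hypothetical helper (not a PyTorch API) that deletes the parameter and then assigns the module; it also accepts dotted names such as "block.scale" by looking up the parent with get_submodule, and it assumes the name refers to a registered parameter. forward returns x here instead of printing it.

```python
import torch
import torch.nn as nn


class change_to_layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(100, 100))

    def __mul__(self, other):
        # lets `self.scale * x` in simple_model.forward call this layer
        return self.forward(other)

    def __rmul__(self, other):
        return self.forward(other)

    def forward(self, x):
        return x @ self.w


class simple_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 100)
        self.scale = nn.Parameter(torch.ones(1))
        self.fc2 = nn.Linear(100, 100)

    def forward(self, x):
        x = self.fc1(x)
        x = self.scale * x
        x = self.fc2(x)
        return x


def replace_parameter(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: swap the parameter at dotted path `name` for `new_module`."""
    parent_name, _, attr = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    delattr(parent, attr)              # removes the entry from parent._parameters
    setattr(parent, attr, new_module)  # now registered in parent._modules


model = simple_model()
replace_parameter(model, "scale", change_to_layer())  # same as: del model.scale; model.scale = change_to_layer()

out = model(torch.randn(100))
print(out.shape)  # torch.Size([100])
print([n for n, _ in model.named_parameters()])
```

Note that an optimizer created from model.parameters() before the swap still holds the old parameter list, so create (or recreate) it after replacing the attribute.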
