Can't assign to parameter, but is actually None

I have tried to reproduce an error from my real code in the simple snippet below.
Why does the following code throw an error?
If you run this code, it fails on the second pass through the training loop, at line 18 ("self.matrix2 = result"), with "TypeError: cannot assign 'torch.FloatTensor' as parameter 'matrix2' (torch.nn.Parameter or None expected)". My issue is that when the code enters the if block, we already know that self.matrix2 is None. So why does this error happen?

The fix is simple: use the commented-out line in miniBatchStep instead. I just don't completely understand why the error happens in the first place.

```python
import torch
import torch.nn as nn


def func(vv):
    uu = [vv[:, 0:1], vv[:, 1:2]]
    return torch.hstack(uu)

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.matrix = nn.Parameter(torch.randn(2, 2), requires_grad=True)
        self.matrix2 = None

    def getModifiedParam(self):
        if self.matrix2 is None:
            result = func(self.matrix)
            self.matrix2 = result
        return self.matrix2

    def forward(self, x):
        return x @ self.getModifiedParam()

    def miniBatchStep(self):
        self.matrix2 = None
        with torch.no_grad():
            self.matrix.data = self.getModifiedParam()
        self.matrix2 = self.matrix
        # self.matrix2 = self.matrix.clone()


def main():
    net = Test()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
    for _ in range(2):
        x = torch.randn(2, 2)
        net(x).sum().backward()
        optimizer.step()
        with torch.no_grad():
            optimizer.zero_grad()
        net.miniBatchStep()

if __name__ == "__main__":
    main()
```

I get the error with Python 3.11.3 and torch 2.2.1 on macOS, and with Python 3.8.10 and torch 2.1.2+cu118 on Ubuntu 20.04.4 LTS.

Hi Taha!

The short story is that Modules treat Parameters specially.

For example, quoting from the Parameter documentation:

Parameters are Tensor subclasses, that have a very special property when used with Modules - when they’re assigned as Module attributes they are automatically added to the list of its parameters

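In particular, once a Parameter has been assigned to an attribute of a Module,
the Module registers that attribute name as a parameter, and from then on you
can't assign anything other than a Parameter or None to it. Assigning None
does not undo the registration; the name stays registered with the value None.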
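To see the bookkeeping behind this, the sketch below inspects a bare nn.Module. It peeks at m._parameters, which is an internal (underscore-prefixed) attribute rather than public API, so that part is an assumption about current PyTorch internals (it matches the 2.x source as far as I know). The point it illustrates: after assigning None, the name is still registered, which is exactly why a later plain-Tensor assignment is rejected.

```python
import torch
import torch.nn as nn

m = nn.Module()

# Assigning a Parameter registers it under the attribute name.
m.s = nn.Parameter(torch.ones(3))
print([name for name, _ in m.named_parameters()])  # ['s']

# Assigning None keeps "s" in the (internal) registry; only the value becomes None.
m.s = None
print("s" in m._parameters, m._parameters["s"])  # True None
print([name for name, _ in m.named_parameters()])  # [] (None entries are skipped)

# Because "s" is still a registered parameter name, a plain Tensor is rejected.
try:
    m.s = torch.zeros(3)
except TypeError as e:
    print("TypeError:", e)
```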

Consider this example script:

```python
import torch
print (torch.__version__)

class MyClass():
    def __init__ (self):
        pass

class MyModule (torch.nn.Module):
    def __init__ (self):
        super().__init__()

myClass = MyClass()
myModule = MyModule()

t = torch.zeros (3)
p = torch.nn.Parameter (torch.ones (3))

myClass.s = p       # no constraint on property s
print (type (myClass.s))
myClass.s = None
print (type (myClass.s))
myClass.s = t       # no issue
print (type (myClass.s))

myModule.s = p      # property s must from now on be a Parameter or None
print (type (myModule.s))
myModule.s = None   # type of s is NoneType
print (type (myModule.s))
myModule.s = t      # raises TypeError
```

Here is its output:

```
2.2.1
<class 'torch.nn.parameter.Parameter'>
<class 'NoneType'>
<class 'torch.Tensor'>
<class 'torch.nn.parameter.Parameter'>
<class 'NoneType'>
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 33, in <module>
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1708, in __setattr__
    raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
TypeError: cannot assign 'torch.FloatTensor' as parameter 's' (torch.nn.Parameter or None expected)
```

The first time through, the line self.matrix2 = result assigns a Tensor to
self.matrix2. At this point self.matrix2 has never been assigned a Parameter,
so the assignment succeeds.

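But self.matrix is a Parameter, so the line self.matrix2 = self.matrix in
miniBatchStep() assigns a Parameter to self.matrix2, and "matrix2" is now
registered as a parameter name. The next time miniBatchStep() runs,
self.matrix2 = None is allowed (None is accepted), so getModifiedParam()
enters the if block, and the line self.matrix2 = result then tries to assign
a plain Tensor to what is now flagged to be either a Parameter or None.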
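Here is a trimmed replay of the question's module that records, on each pass, whether "matrix2" is already registered and whether miniBatchStep() raises. It is a sketch: func is inlined, and the backward/optimizer calls are dropped on the assumption that they don't affect which assignment fails. It again relies on the internal _parameters dict.

```python
import torch
import torch.nn as nn

class Test(nn.Module):
    # Same structure as the question's module, with func() inlined.
    def __init__(self):
        super().__init__()
        self.matrix = nn.Parameter(torch.randn(2, 2))
        self.matrix2 = None  # plain attribute at this point, not a parameter

    def getModifiedParam(self):
        if self.matrix2 is None:
            self.matrix2 = torch.hstack([self.matrix[:, 0:1], self.matrix[:, 1:2]])
        return self.matrix2

    def forward(self, x):
        return x @ self.getModifiedParam()

    def miniBatchStep(self):
        self.matrix2 = None
        with torch.no_grad():
            self.matrix.data = self.getModifiedParam()
        self.matrix2 = self.matrix  # registers "matrix2" as a parameter name

net = Test()
results = []
for it in range(2):
    registered_before = "matrix2" in net._parameters
    net(torch.randn(2, 2))  # forward only recomputes when matrix2 is None
    try:
        net.miniBatchStep()
        outcome = "ok"
    except TypeError:
        outcome = "TypeError"
    results.append((it, registered_before, outcome))

print(results)  # [(0, False, 'ok'), (1, True, 'TypeError')]
```

Note that the forward call in the second pass does not fail, because self.matrix2 still holds the Parameter there; the error only appears after miniBatchStep() has reset it to None.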

However, the .clone() of a Parameter is just a Tensor (rather than still
a Parameter), so when you use the commented-out line, you avoid ever
assigning a Parameter to self.matrix2 and no error occurs.
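A quick check of the .clone() fix, as a sketch: the class Fixed below is a hypothetical copy of the question's Test with the commented-out line swapped in, run for a few iterations with the same Adam setup.

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.randn(2, 2))
print(type(p).__name__, type(p.clone()).__name__)  # Parameter Tensor

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.matrix = nn.Parameter(torch.randn(2, 2))
        self.matrix2 = None

    def getModifiedParam(self):
        if self.matrix2 is None:
            self.matrix2 = torch.hstack([self.matrix[:, 0:1], self.matrix[:, 1:2]])
        return self.matrix2

    def forward(self, x):
        return x @ self.getModifiedParam()

    def miniBatchStep(self):
        self.matrix2 = None
        with torch.no_grad():
            self.matrix.data = self.getModifiedParam()
        # A clone is a plain Tensor, so "matrix2" is never registered as a parameter.
        self.matrix2 = self.matrix.clone()

net = Fixed()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
for _ in range(3):
    net(torch.randn(2, 2)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    net.miniBatchStep()

print([name for name, _ in net.named_parameters()])  # ['matrix']
```

Because clone() is differentiable, the next forward pass still backpropagates into self.matrix through the clone.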

Best.

K. Frank

Ah, OK. Thanks a lot. Now it makes sense.
Once a Parameter is assigned to an attribute of an nn.Module subclass, the module will always treat that attribute name as a parameter.
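Strictly speaking, the registration can be undone: nn.Module also overrides __delattr__, so del removes the name from the parameter registry, after which a plain Tensor can be assigned again. A minimal sketch, again peeking at the internal _parameters dict (an assumption about current internals, not public API):

```python
import torch
import torch.nn as nn

m = nn.Module()
m.w = nn.Parameter(torch.ones(2))
m.w = None                      # still registered as a parameter name
print("w" in m._parameters)     # True

del m.w                         # Module.__delattr__ drops the entry from _parameters
print("w" in m._parameters)     # False

m.w = torch.zeros(2)            # now an ordinary attribute, so no TypeError
print(type(m.w).__name__)       # Tensor
```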

Thanks a lot.