Reshaping torch.nn.Parameter

I’m writing a class that is basically just nn.Linear except it has some reshaping.

import sys
import torch
import torch.nn as nn

class myLinear2(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
    self.bias = torch.nn.Parameter(torch.randn(out_features))

  def forward(self, input):
    # reshape weight to (out_features, 1, in_features) and bias to (out_features, 1, 1)
    self.weight = self.weight.reshape(-1, 1, self.weight.shape[-1])
    self.bias = self.bias.reshape(self.bias.shape[0], 1, 1)
    if input.dim() > 1:
      x, y = input.shape
      if y != self.in_features:
        sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
    output = (input @ self.weight.transpose(-1, -2)) + self.bias
    return output

Running a model with this class returns an error:
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

However, if I put the reshape inside an if statement, it works whether or not the condition is actually true.

# this version works
if self.weight.dim() > 2:
  self.weight = self.weight.reshape(-1, 1, self.weight.shape[-1])
  self.bias = self.bias.reshape(self.bias.shape[0], 1, 1)

What is going on here?

I don't think your code would work inside the condition either; I'm seeing the same error if I increase the number of dimensions of self.weight so that the branch is actually executed.
I'm not familiar with your use case, but maybe you could avoid re-assigning the reshaped tensor and just use it in the calculation, which would avoid this error.
If not, you could try to wrap the assignment in a no_grad() guard and use e.g. self.weight.copy_ to avoid creating new nn.Parameters, which you would otherwise need to pass to an optimizer again.
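
For illustration, here is a minimal sketch of that in-place pattern (my own example, not code from the thread). Note that copy_ only updates the values; it cannot change the parameter's shape:

import torch
import torch.nn as nn

p = nn.Parameter(torch.randn(4, 3))
with torch.no_grad():
  p.copy_(torch.ones(4, 3))  # in-place update; p stays the same nn.Parameter object
print(type(p))  # <class 'torch.nn.parameter.Parameter'>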

Thanks, I'll try this out. My use case is a little unusual since it involves Pyro PPL. I have a post with a reproducible example on their forum: Proper shape for Predictive with parallel=True - Pyro Discussion Forum. I was trying to figure out whether I could use reshaping to get the proper output for the various weight shapes I describe there.

Let’s simplify the problem:

import torch
import torch.nn as nn

class Foo(nn.Module):
  def __init__(self):
    super().__init__()
    self.p = nn.Parameter(torch.randn(1))

  def forward(self, input):
    self.p = self.p.reshape(1, -1)  # re-assigns a plain Tensor to the Parameter attribute
    return input * self.p

x = torch.randn(5)
layer = Foo()
y = layer(x)  # raises the TypeError from above

This will not work, and your error message hints at why. A simple test we could do:

def forward(self, input):
    print(type(self.p))
    print(type(self.p.reshape(1, -1)))
    self.p = self.p.reshape(1, -1)
    return input * self.p

This will show:

<class 'torch.nn.parameter.Parameter'>
<class 'torch.Tensor'>

So calling reshape on a Parameter returns a plain torch.Tensor, and nn.Module refuses to assign a plain tensor to an attribute that is registered as a Parameter, hence the TypeError. One could then think of simply wrapping self.p.reshape() in a new nn.Parameter. While that would resolve the error, it would create a new parameter that the optimizer is not tracking and is thus of no use.
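
For illustration (my own sketch, not code from the post), the re-wrapping approach warned against above would look roughly like this:

def forward(self, input):
    # this silences the TypeError, but creates a brand-new Parameter on every call,
    # which an optimizer constructed beforehand would not be tracking
    self.p = nn.Parameter(self.p.reshape(1, -1))
    return input * self.p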

One approach to try: modify the underlying .data tensor, which bypasses the Parameter type check.

def forward(self, input):
    # .data is the raw tensor stored inside the Parameter; re-assigning it keeps self.p an nn.Parameter
    self.p.data = self.p.data.reshape(1, -1)
    return input * self.p
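
A quick check (my own, using the same Foo example) shows that the .data assignment permanently changes the stored parameter's shape after the first forward pass:

layer = Foo()
print(layer.p.shape)   # torch.Size([1])
y = layer(torch.randn(5))
print(layer.p.shape)   # torch.Size([1, 1])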

Or, we could avoid modifying the parameter at all and just store the reshaped view in a temporary variable.

def forward(self, input):
    p = self.p.reshape(1, -1)  # reshape returns a view; gradients still flow back to self.p
    return input * p

Of course, it would be best not to modify self.p inside forward at all, and to give the parameter the right shape directly in __init__.
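
For completeness, a minimal sketch of that last suggestion (my own example, reusing the Foo module from above):

class Foo(nn.Module):
  def __init__(self):
    super().__init__()
    # create the parameter with its final shape up front
    self.p = nn.Parameter(torch.randn(1, 1))

  def forward(self, input):
    return input * self.p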