Scalar parameters not showing up in model.parameters()

I am getting "ValueError: optimizer got an empty parameter list" when I run this code. What am I missing? I tried using nn.Parameter and still get an error on the model parameters a and b. I just want to track and update these two scalars in this toy problem.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

class RosenbrockFunction2N(nn.Module):
  def __init__(self):
    super(RosenbrockFunction2N, self).__init__()
    self.a = torch.rand(1, requires_grad=True)
    self.b = torch.rand(1, requires_grad=True)

  def forward(self, x, y):
    return (self.a - x)**2 + self.b(y-x**2)**2

# Dataset
inputs = [[1., 3.], [5., 7.], [2., 4.], [3., 9.]]
targets = []
for i in inputs:
  targets.append([(1-i[0])**2 + (i[1] - i[0]**2)**2])
inputs = torch.tensor(inputs)
targets = torch.tensor(targets)

train_ds = TensorDataset(inputs, targets)
train_dl = DataLoader(train_ds)

# Load model
model = RosenbrockFunction2N()

# Loss and Optimizer
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    
for i in range(5):
  for inputs, target in train_dl:
    optimizer.zero_grad()
    output = model(inputs[0][0], inputs[0][1])
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

Registering a and b as nn.Parameter should work (alternatively, you could use self.register_parameter). In your current code you are assigning plain tensors as module attributes. These won't be registered as parameters internally, so model.parameters() is empty and the optimizer has nothing to optimize.
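To make the difference visible, here is a minimal sketch comparing three ways of storing the two scalars: plain tensors, nn.Parameter attributes, and explicit self.register_parameter calls. The class names are hypothetical and only for illustration. Only the last two show up in named_parameters(), and building an optimizer from the first one reproduces the empty-parameter-list ValueError. Note that nn.Parameter already defaults to requires_grad=True, so passing requires_grad=True to the wrapped tensor is not needed.

```python
import torch
import torch.nn as nn


class PlainTensors(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensors: stored as ordinary attributes, NOT registered
        self.a = torch.rand(1, requires_grad=True)
        self.b = torch.rand(1, requires_grad=True)


class WithParameters(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter (requires_grad=True by default) is registered
        # automatically when assigned as a module attribute
        self.a = nn.Parameter(torch.rand(()))
        self.b = nn.Parameter(torch.rand(()))


class WithRegisterParameter(nn.Module):
    def __init__(self):
        super().__init__()
        # Equivalent explicit registration
        self.register_parameter("a", nn.Parameter(torch.rand(())))
        self.register_parameter("b", nn.Parameter(torch.rand(())))


for cls in (PlainTensors, WithParameters, WithRegisterParameter):
    names = [name for name, _ in cls().named_parameters()]
    print(cls.__name__, names)

# An optimizer built from the unregistered version receives no parameters
try:
    torch.optim.SGD(PlainTensors().parameters(), lr=0.01)
except ValueError as err:
    print("ValueError:", err)
```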

Thanks, @ptrblck. Sorry, I’m not used to working with scalars like this. When I change it to this:

    self.a = nn.Parameter(torch.tensor(1., requires_grad=True))
    self.b = nn.Parameter(torch.tensor(1., requires_grad=True))

I get this:
TypeError: 'Parameter' object is not callable

This error is raised by trying to call the parameter in:

self.b(y-x**2)**2

What kind of operation should be performed in this line of code?

Such a simple problem. Yeah, it's just a missing multiplication symbol: self.b * (y-x**2)**2. Thanks, @ptrblck!
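For reference, here is a sketch of the whole script with both fixes from this thread applied: a and b registered as nn.Parameter, and the missing * restored in forward. A few details are assumptions rather than part of the thread. The targets use (1 - x)**2 for the first term, so that a = 1, b = 1 reproduce them exactly. The whole dataset is used as one batch (batch_size=4), so the output and target shapes match. The learning rate is lowered to 1e-5, since the (y - x**2)**2 factor is 324 for the [5, 7] row and lr=0.01 is likely to make SGD diverge. The loop variable is also renamed to xy so it no longer shadows the inputs tensor.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader


class RosenbrockFunction2N(nn.Module):
    def __init__(self):
        super().__init__()
        # Scalars wrapped in nn.Parameter are registered automatically
        # (requires_grad=True is the nn.Parameter default)
        self.a = nn.Parameter(torch.rand(()))
        self.b = nn.Parameter(torch.rand(()))

    def forward(self, x, y):
        # Explicit "*": writing self.b(...) would try to call the Parameter
        return (self.a - x) ** 2 + self.b * (y - x ** 2) ** 2


# Assumed intent: targets produced by a = 1, b = 1
inputs = torch.tensor([[1., 3.], [5., 7.], [2., 4.], [3., 9.]])
targets = (1 - inputs[:, 0]) ** 2 + (inputs[:, 1] - inputs[:, 0] ** 2) ** 2

# Whole toy dataset as a single batch of 4 rows
train_dl = DataLoader(TensorDataset(inputs, targets), batch_size=4)

model = RosenbrockFunction2N()
a_init = model.a.detach().clone()  # kept only to compare after training
b_init = model.b.detach().clone()

loss_fn = nn.MSELoss()
# Small lr (assumption): (y - x**2)**2 is 324 for the [5, 7] row
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)

for epoch in range(100):
    for xy, target in train_dl:
        optimizer.zero_grad()
        output = model(xy[:, 0], xy[:, 1])  # x and y columns, shape (4,)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

print(f"a = {model.a.item():.4f}, b = {model.b.item():.4f}, loss = {loss.item():.4f}")
```

How far a and b move toward 1 depends on the learning rate and the number of epochs, so treat the values above as tunable rather than prescribed.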