CNN model parameters are not storing properly

The task at work is to make a CNN model to do some classification task on images. Also, I should be able to view the feature maps after making a classification on an image i.e the images obtained after applying convolution or pooling operations. Below is how I defined my CNN class:

class ConvNet(nn.Module):
  def __init__(self, input_channels, output_dim):
    super().__init__()
    # input 48x48
    self.architecture = {
        "conv1": self.convblock(input_channels, 128, (3,3)), # 46x46
        "conv2" : self.convblock(128, 64, (3, 3), bnorm=True), # 44x44
        "pool1" : self.poolblock((2,2)), # 22x22
        "conv3" : self.convblock(64, 16, (3,3), stride=2), #10x10
        "conv4" : self.convblock(16, 10, (3,3)), # 8x8
        "pool2" : self.poolblock((2,2), bnorm=10), # 4x4
        "feedforward" : nn.Sequential(
          nn.Flatten(), # 4x4x10 = 160
          nn.Linear(160, 128), # 128
          nn.ReLU(inplace=True),
          nn.Dropout(0.3),
          nn.Linear(128, output_dim), # 3
          nn.Softmax(dim=1)
        )                  
    }
    self.maps = {}

  def forward(self, x):
    image = x
    for name, layer in self.architecture.items():
      out = layer(image)
      self.maps[name] = out
      image = out
    return image
     
  def convblock(self, inp, out, kernel, stride=1, bnorm=False):
    if bnorm:
      return nn.Sequential(
        nn.Conv2d(inp, out, kernel, stride=stride),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out)
      )
    else:
      return nn.Sequential(
        nn.Conv2d(inp, out, kernel, stride=stride),
        nn.ReLU(inplace=True)
      )

  def poolblock(self, kernel, bnorm=None):
    if bnorm is None:
      return nn.MaxPool2d(kernel)
    else:
      return nn.Sequential(
          nn.MaxPool2d(kernel),
          nn.BatchNorm2d(bnorm)
      )

  def get_map(self, im, layer):
    fig, ax = plt.subplots(1,2, figsize=(20,10), gridspec_kw={'width_ratios': [1,3]})
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(im.reshape(im.shape[-2],im.shape[-1],1), cmap="gray") # Shows Input image
    self(im)
    map = self.maps[layer]
    map=map.reshape(map.shape[1],1,map.shape[-2],map.shape[-1])
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    rows = max(int(map.shape[0]/8), 8)
    ax[1].imshow(make_grid(map,nrow=rows).permute(1, 2, 0)) # Shows all the channels after an operation.

The idea is to store blocks of convolution and pooling layers in the self.architecture dictionary with names such as ‘conv1’, ‘conv2’, ‘pool1’ etc…
Then, in the forward method, I would run the input images through each block and also store the outputs of each block in the self.maps dictionary to retrieve later (self.get_map does this).

The problem is that the model’s parameters are not being set properly. Below is the code where I instantiated the model and optimizer:

model = ConvNet(1, 3)
adam_opt = torch.optim.Adam(model.parameters(), lr=0.01)

However I ran into the following error:

/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
    271         param_groups = list(params)
    272         if len(param_groups) == 0:
--> 273             raise ValueError("optimizer got an empty parameter list")
    274         if not isinstance(param_groups[0], dict):
    275             param_groups = [{'params': param_groups}]

ValueError: optimizer got an empty parameter list

I printed out the parameter list and its empty. I don’t understand why this is happening.

Is my way of defining the architecture or forward method wrong and pytorch expects some specific behaviour when subclassing nn.Module? If so what is it and how can I change my class? Any extra information on how pytorch actually stores the class parameters is most welcomed.

Thank you.

self.architecture is a plain Python dict and will thus not properly register the modules.
Use nn.ModuleDict and it should work.

Thanks @ptrblck. Can you detail on how this works? Docs on nn.ModuleDict says that it registers the dictionary but I don’t know what this means.

Yes, nn.ModuleDict will make sure to call register_buffer or register_parameter appropriately and make sure these parameters and buffers are properly registered. This will allow you to return them via model.parameters()/.buffers() and will also move them to the desired device, dtype etc. as seen here:

class MyWrongModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.not_registered = {}
        self.not_registered["fc"] = nn.Linear(1, 1)
        
    def forward(self, x):
        out = self.not_registered["fc"](x)
        return out
    
model = MyWrongModel()
x = torch.randn(1, 1)
out = model(x)

print(dict(model.named_parameters()))
# {}

print(model.not_registered["fc"].weight)
# Parameter containing:
# tensor([[0.3274]], requires_grad=True)

model.to("cuda")
print(model.not_registered["fc"].weight)
# Parameter containing:
# tensor([[0.3274]], requires_grad=True)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.ModuleDict()
        self.registered["fc"] = nn.Linear(1, 1)
        
    def forward(self, x):
        out = self.registered["fc"](x)
        return out
    
model = MyModel()
x = torch.randn(1, 1)
out = model(x)

print(dict(model.named_parameters()))
# {'registered.fc.weight': Parameter containing:
# tensor([[0.6348]], requires_grad=True), 'registered.fc.bias': Parameter containing:
# tensor([-0.3697], requires_grad=True)}

print(model.registered["fc"].weight)
# Parameter containing:
# tensor([[0.6348]], requires_grad=True)

model.to("cuda")
print(model.registered["fc"].weight)
# Parameter containing:
# tensor([[0.6348]], device='cuda:0', requires_grad=True)