Optimizer.state_dict() for an autoencoder neural network

Hi everyone,
I have developed a convolutional autoencoder neural network for extracting features from an image dataset. The current architecture of the encoder is as follows (the decoder part is simply the reverse):

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1, dilation=1),
            nn.ReLU(),
            # nn.BatchNorm2d(8),

            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1, dilation=1),
            nn.ReLU(),
            # nn.BatchNorm2d(16),

            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=1, dilation=1),
            nn.ReLU(),
            # nn.BatchNorm2d(16),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1, dilation=1),
            nn.ReLU()
            # nn.BatchNorm2d(32),

            # nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, dilation=1),
            # nn.ReLU(),
            # nn.BatchNorm2d(32)
        )

        self.flatten = nn.Flatten(start_dim=1)

        self.encoder_lin = nn.Sequential(
            nn.Linear(32 * 8 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 32)
        )

    def forward(self, x):
        x = self.encoder_conv(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
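
For reference, the 32 * 8 * 8 input size of the first linear layer assumes 1 x 128 x 128 inputs: each of the four stride-2 convolutions roughly halves the spatial resolution (128 -> 64 -> 32 -> 16 -> 8). A quick shape check (a sketch of mine; the input size is an assumption, not stated in the post):

# Quick shape check for the encoder above (input size 128 x 128 assumed)
encoder = Encoder()
dummy = torch.randn(4, 1, 128, 128)        # batch of 4 grayscale images
print(encoder.encoder_conv(dummy).shape)   # torch.Size([4, 32, 8, 8])
print(encoder(dummy).shape)                # torch.Size([4, 32])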
                                                              

The function for the training is as follows:

import numpy as np


def train_autoencoder(encoder, decoder, dataloader, loss_fn, optimizer):
    encoder.train()
    decoder.train()
    train_loss = []

    for image, label in dataloader:
        # Forward pass: encode, then reconstruct
        encoded_features = encoder(image)
        decoded_features = decoder(encoded_features)

        # Reconstruction loss against the original image
        loss = loss_fn(decoded_features, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    return np.mean(train_loss), encoder, decoder, optimizer

Regarding the code provided here, I used the following class for saving the best model parameters:

# Class for saving the best model weights
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is less than the previous lowest loss, save the
    model state.
    """
    def __init__(self, best_loss=float('inf')):
        self.best_loss = best_loss

    def __call__(self, current_loss, epoch, model, optimizer):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
            }, 'checkpoint.pth')
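
Note that the class compares against a validation loss, while train_autoencoder above only returns the training loss. If a proper validation loss is wanted for the checkpoint decision, a matching evaluation pass could look roughly like this (a sketch of mine, not part of the original code, reusing the same loss function):

# Sketch of a validation pass to pair with SaveBestModel
def eval_autoencoder(encoder, decoder, dataloader, loss_fn):
    encoder.eval()
    decoder.eval()
    val_loss = []
    with torch.no_grad():
        for image, label in dataloader:
            reconstruction = decoder(encoder(image))
            val_loss.append(loss_fn(reconstruction, image).item())
    return np.mean(val_loss)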

Finally, I used this code snippet:

# Moving the datasets to the GPU beforehand

train_list = []
test_list = []

initial_dataloader_train = DataLoader(train_dataset, batch_size=64, shuffle=True)
initial_dataloader_test = DataLoader(test_dataset, batch_size=64, shuffle=True)

# First move the train dataset
for image, label in initial_dataloader_train:
    image, label = image.to('cuda'), label.to('cuda')
    train_list.append(torch.utils.data.TensorDataset(image, label))

# Then move the test (validation) dataset
for image, label in initial_dataloader_test:
    image, label = image.to('cuda'), label.to('cuda')
    test_list.append(torch.utils.data.TensorDataset(image, label))

gpu_train_dataset = torch.utils.data.ConcatDataset(train_list)
gpu_test_dataset = torch.utils.data.ConcatDataset(test_list)


# Train the autoencoder
encoder = Encoder()
decoder = Decoder()
encoder.to(device)
decoder.to(device)

loss_fn = nn.BCEWithLogitsLoss()

# One optimizer over both the encoder and the decoder parameters
params_to_optimize = [{'params': encoder.parameters()}, {'params': decoder.parameters()}]
optimizer = torch.optim.Adam(params_to_optimize, lr=0.001)

save_best_model = SaveBestModel()

dataloader_train = DataLoader(gpu_train_dataset, batch_size=len(train_indices),
                              shuffle=True, pin_memory=False, num_workers=0)
dataloader_test = DataLoader(gpu_test_dataset, batch_size=len(test_indices),
                             shuffle=True, pin_memory=False, num_workers=0)

num_epochs = 250

for epoch in range(num_epochs):
    loss_autoencoder, encoder, decoder, optimizer = train_autoencoder(
        encoder=encoder, decoder=decoder,
        dataloader=dataloader_train, loss_fn=loss_fn,
        optimizer=optimizer)

    # Save a checkpoint whenever the loss improves
    save_best_model(loss_autoencoder, epoch, encoder, optimizer)


# Restore the best encoder from the checkpoint
model_encoder = Encoder()
optimizer = torch.optim.Adam(model_encoder.parameters(), lr=0.001)

checkpoint = torch.load('checkpoint.pth')

model_encoder.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

However, the last line raises an error, since the saved optimizer state covers both the encoder and the decoder and cannot be loaded into an optimizer that holds the encoder parameters only. My question is: what is the best way to split the autoencoder optimizer into its encoder and decoder parts?

Best Regards

You could create two separate optimizers, one for each model, or alternatively you could try to create the state_dict for the encoder optimizer from the saved checkpoint.
To do so, create a fake optimizer for the encoder only, check its state_dict, and try to recreate the same dict structure from the checkpoint.
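
For the first suggestion, a minimal sketch (the variable names are mine; the Adam settings match the ones used above) would be to give each sub-model its own optimizer and checkpoint their state_dicts separately:

# Option 1 (sketch): one optimizer per sub-model
encoder = Encoder().to(device)
decoder = Decoder().to(device)

optimizer_enc = torch.optim.Adam(encoder.parameters(), lr=0.001)
optimizer_dec = torch.optim.Adam(decoder.parameters(), lr=0.001)

# In the training step, zero and step both optimizers around a single backward pass:
#   optimizer_enc.zero_grad(); optimizer_dec.zero_grad()
#   loss.backward()
#   optimizer_enc.step(); optimizer_dec.step()

# The checkpoint can then store each state_dict under its own key, e.g.
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_enc_state_dict': optimizer_enc.state_dict(),
    'optimizer_dec_state_dict': optimizer_dec.state_dict(),
}, 'checkpoint.pth')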


Thank you very much for your quick response.

I followed the first suggestion, which solved the problem.

Regarding the second method, I retrieved the following information from the encoder optimizer and the full autoencoder optimizer:

encoder_optimizer.state_dict()['state'].keys() = dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
total_autoencoder_optimizer.state_dict()['state'].keys() = dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27])

encoder_optimizer.state_dict()['param_groups'] = [{'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}]
total_autoencoder_optimizer.state_dict()['param_groups'] = [{'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]},
 {'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'params': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}]
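
The 14 keys of the encoder optimizer line up with the encoder's parameter tensors: four conv layers and three linear layers, each contributing a weight and a bias (4 x 2 + 3 x 2 = 14), while keys 14-27 belong to the decoder. A quick check, assuming the models defined above:

print(len(list(encoder.parameters())))   # 14 tensors for the encoder
print(len(list(decoder.parameters())))   # 14 tensors for the mirrored decoder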

Then, the new “state” and “param_groups” for the encoder optimizer could be created this way:

optimizer_encoder_state = dict((k, optimizer.state_dict()['state'][k]) for k in range(14))

optimizer_encoder_param_groups = optimizer.state_dict()['param_groups'][0]

What would be the best way to create the new optimizer?

Actually, I think that the following approach would not be the answer:

# Starting the new encoder model and optimizer

model_encoder = Encoder().to(device)

best_optimizer = torch.optim.Adam(params_to_optimize, lr=0.001)

checkpoint = torch.load('checkpoint.pth')

model_encoder.load_state_dict(checkpoint['encoder_state_dict'])
best_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

optimizer_encoder_state = dict((k, best_optimizer.state_dict()['state'][k]) for k in range(14))
optimizer_encoder_param_groups = best_optimizer.state_dict()['param_groups'][0]

optimizer_encoder = torch.optim.Adam(model_encoder.parameters(), lr=0.001)

optimizer_encoder.state_dict()['param_groups'] = optimizer_encoder_param_groups
optimizer_encoder.state_dict()['state'] = optimizer_encoder_state
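
Indeed, the last two lines have no effect: optimizer.state_dict() builds and returns a fresh dict on every call, so assigning into that dict never changes the optimizer itself. A version that should work (a sketch, assuming the layout shown in the dumps above, i.e. state entries 0-13 and the first param group belong to the encoder) assembles the sliced dict and hands it to load_state_dict():

# Sketch: rebuild an encoder-only optimizer state from the full checkpoint
full_sd = best_optimizer.state_dict()

encoder_sd = {
    'state': {k: full_sd['state'][k] for k in range(14)},   # encoder entries only
    'param_groups': [full_sd['param_groups'][0]],           # the encoder param group
}

optimizer_encoder = torch.optim.Adam(model_encoder.parameters(), lr=0.001)
optimizer_encoder.load_state_dict(encoder_sd)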

Note: I modified the SaveBestModel class as follows so that I can access encoder.state_dict() more easily:

# Class for saving the best model weights
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is less than the previous lowest loss, save the
    model state.
    """
    def __init__(self, best_loss=float('inf')):
        self.best_loss = best_loss

    def __call__(self, current_loss, epoch, encoder, decoder, optimizer):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            torch.save({
                'epoch': epoch + 1,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
            }, 'checkpoint.pth')
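
With this new __call__ signature, the call inside the epoch loop presumably has to pass both sub-models as well, along the lines of:

# Updated checkpoint call inside the epoch loop (matching the new signature)
save_best_model(loss_autoencoder, epoch, encoder, decoder, optimizer)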