Updating Adam Optimizer After Modifying Model Architecture

import torch.nn as nn
import torch

class Foo(nn.Module):
    def __init__(self, input_size=20, output_size=100):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        self.fc0 = nn.Linear(self.input_size, 30)
        self.fc1 = nn.Linear(30, 40)
        self.fc_out = nn.Linear(40, self.output_size)

    def increment_output_size(self, copy_idx: int):
        old_output_size = self.output_size
        old_fc_out = self.fc_out

        self.output_size += 1
        # Create the wider layer on the same device/dtype as the old one (matters once the model is on a GPU)
        self.fc_out = nn.Linear(old_fc_out.in_features, self.output_size,
                                device=old_fc_out.weight.device, dtype=old_fc_out.weight.dtype)
        with torch.no_grad():
            # Keep the existing rows and duplicate row copy_idx into the new last row
            self.fc_out.weight[:old_output_size] = old_fc_out.weight
            self.fc_out.weight[-1] = old_fc_out.weight[copy_idx]
            self.fc_out.bias[:old_output_size] = old_fc_out.bias
            self.fc_out.bias[-1] = old_fc_out.bias[copy_idx]


if __name__ == "__main__":
    # SETUP MODEL AND OPTIMIZER
    model = Foo()
    optimizer = torch.optim.Adam(model.parameters())

    # DO SOME TRAINING HERE (the Adam optimizer will hold state for each parameter)
    # ...

    # MODIFY MODEL
    model.increment_output_size(copy_idx=25)

    # UPDATE OPTIMIZER
    #  1. For parameters (weights/bias) of model.fc0 and model.fc1, the state should be retained.
    #  2. For model.fc_out, which has been replaced:
    #     a. for model.fc_out.weight/bias[:100], preserve the corresponding Adam states
    #     b. for model.fc_out.weight/bias[100] (the new last row), clone the state of model.fc_out.weight/bias[25]

    # DO SOME MORE TRAINING HERE
    # ...

    print("Done")

The code above sketches training a model that, during training, may dynamically expand its self.fc_out module to accommodate an increase in the number of classes. I want the existing Adam optimizer to update its internal state accordingly whenever the model is modified. This means updating the relevant entries in both self.updater.optimizer.param_groups[0]["params"] and self.updater.optimizer.state (optimizer in the snippet above), and anything else, if necessary.

How can I do this? Specifically:

  1. How can I index the optimizer fields/keys relevant to self.fc_out?
  2. How do I replace the relevant optimizer params and states? Is it sufficient to do self.updater.optimizer.param_groups[0]["params"][index_of_fc_out_weight] = model.fc_out.weight and self.updater.optimizer.state[index_of_fc_out_weight] = new_fc_out_state?
  3. Are there any other fields or under-the-hood mechanisms that I need to be aware of?
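For reference, a minimal sketch of how Adam's internals line up with the model's parameters, assuming a recent PyTorch release. The small nn.Sequential is a hypothetical stand-in for Foo (layers 0, 1 and 2 play the roles of fc0, fc1 and fc_out). param_groups[0]["params"] is a plain list holding the very same Parameter objects that model.parameters() yields, and optimizer.state is a dict keyed by those objects, so the entries belonging to fc_out are found by identity (is), not by value.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Foo: layers 0, 1 and 2 play the roles of fc0, fc1 and fc_out
model = nn.Sequential(nn.Linear(20, 30), nn.Linear(30, 40), nn.Linear(40, 100))
optimizer = torch.optim.Adam(model.parameters())

# One training step so that Adam creates its per-parameter state (it does so lazily)
model(torch.randn(8, 20)).sum().backward()
optimizer.step()

fc_out = model[2]
params = optimizer.param_groups[0]["params"]

# Locate fc_out's tensors by identity; the list holds the same objects as model.parameters()
weight_idx = next(i for i, p in enumerate(params) if p is fc_out.weight)
bias_idx = next(i for i, p in enumerate(params) if p is fc_out.bias)

# optimizer.state is keyed by the parameter tensors themselves, not by integer indices
weight_state = optimizer.state[fc_out.weight]
print(weight_idx, bias_idx)            # 4 5
print(sorted(weight_state.keys()))     # ['exp_avg', 'exp_avg_sq', 'step']
print(weight_state["exp_avg"].shape)   # torch.Size([100, 40])
```

Note that only state_dict() switches to integer indices; the live optimizer.state always uses the tensors as keys.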

Any help is appreciated.


"...train a model that, throughout the training process, may dynamically modify its self.fc_out module to accommodate for increase in the number of classes."

Why would you want to do this? Knowing the motivation may help find a better solution.

I am working in a scenario where the number of classes increases over time. Since I have no way of knowing beforehand how many classes there will be, I want the classification component to be dynamically expandable.

Furthermore, I suspect (for research purposes) that initializing the new class from the weights of an existing, related class may improve training, since the new class starts out as a few-shot problem with scarce training samples. I therefore also want to mirror the Adam states to reflect the change in the output parameters (I am not sure this works, but I intend to try it).
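To make the "mirroring" idea concrete, a small sketch, assuming PyTorch 1.12 or later on CPU; old_w and new_w are hypothetical stand-ins for fc_out.weight before and after expansion. If the new row starts with the same weights and the same Adam moments as row copy_idx, and receives the same gradient, Adam applies identical updates to both rows, so the new class begins as a twin of the copied one and only diverges once the two rows' gradients differ.

```python
import torch

torch.manual_seed(0)
old_w = torch.nn.Parameter(torch.randn(4, 3))  # hypothetical stand-in for fc_out.weight
opt = torch.optim.Adam([old_w], lr=0.1)

# A few steps to build up non-trivial Adam moments for the old parameter
for _ in range(3):
    opt.zero_grad()
    (old_w ** 2).sum().backward()
    opt.step()

copy_idx = 1
# New parameter with one extra row that duplicates row copy_idx
with torch.no_grad():
    new_w = torch.nn.Parameter(torch.cat([old_w, old_w[copy_idx:copy_idx + 1]]))

# Mirror the Adam state the same way: keep the old rows, duplicate row copy_idx
old_state = opt.state.pop(old_w)
step = old_state["step"]
opt.param_groups[0]["params"][0] = new_w
opt.state[new_w] = {
    "step": step.clone() if torch.is_tensor(step) else step,
    "exp_avg": torch.cat([old_state["exp_avg"], old_state["exp_avg"][copy_idx:copy_idx + 1]]),
    "exp_avg_sq": torch.cat([old_state["exp_avg_sq"], old_state["exp_avg_sq"][copy_idx:copy_idx + 1]]),
}

# With this loss the gradient is 2 * w, so equal rows receive equal gradients
for _ in range(3):
    opt.zero_grad()
    (new_w ** 2).sum().backward()
    opt.step()

print(torch.allclose(new_w[copy_idx], new_w[-1]))  # True
```

In real training the two rows get different gradients (different targets), so this only controls the starting point, which is exactly the few-shot warm start described above.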

I hope this makes the motivation clear and helps with the problem.

This seems like quite an interesting problem. I could not find an easy-to-use end-to-end example.

Main optimiser code:

import torch


def _expand_state(old_state, new_param, old_output_size, copy_idx):
    # Copy every per-element Adam buffer (exp_avg, exp_avg_sq, and max_exp_avg_sq
    # when amsgrad=True): keep rows [:old_output_size] and append a copy of row copy_idx
    new_state = {}
    for key, value in old_state.items():
        if key == "step":
            # 'step' is a scalar tensor in recent PyTorch (a plain number in older releases)
            new_state[key] = value.clone() if torch.is_tensor(value) else value
        elif torch.is_tensor(value) and value.shape[:1] == (old_output_size,):
            expanded = value.new_zeros(new_param.shape)  # same dtype/device as the old buffer
            expanded[:old_output_size] = value
            expanded[-1] = value[copy_idx]
            new_state[key] = expanded
        else:
            new_state[key] = value
    return new_state


def update_optimizer_state(optimizer, old_fc_out, new_fc_out, copy_idx, old_output_size):
    # Capture old_fc_out = model.fc_out *before* calling model.increment_output_size(...)
    pairs = ((old_fc_out.weight, new_fc_out.weight), (old_fc_out.bias, new_fc_out.bias))
    for old_param, new_param in pairs:
        # 1. Swap the tensor in whichever param group holds it (matched by identity, not value)
        for group in optimizer.param_groups:
            for i, p in enumerate(group["params"]):
                if p is old_param:
                    group["params"][i] = new_param

        # 2. Move the state. Adam creates it lazily on the first step(), so it may not exist yet;
        #    in that case Adam initializes fresh state for new_param on the next step().
        if old_param in optimizer.state:
            old_state = optimizer.state.pop(old_param)
            optimizer.state[new_param] = _expand_state(old_state, new_param, old_output_size, copy_idx)

    # No further "refresh" is needed: step() reads param_groups and state directly on every call.
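A swapped-in alternative that avoids editing the live optimizer internals: patch the optimizer's state_dict() and load it into a fresh Adam. This is a different technique from the answer above, sketched under two assumptions: a recent PyTorch 2.x, and fc_out's weight and bias sitting at positions 4 and 5 of model.parameters() (true for Foo, since reassigning self.fc_out keeps its registration slot). The nn.Sequential and the expand helper are hypothetical stand-ins. In state_dict(), "state" is keyed by integer position, and load_state_dict() restores hyperparameters such as lr as well.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Foo: fc_out is model[2], so its weight/bias are params 4 and 5
model = nn.Sequential(nn.Linear(20, 30), nn.Linear(30, 40), nn.Linear(40, 100))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 20)).sum().backward()
optimizer.step()

copy_idx, old_out = 25, 100
sd = optimizer.state_dict()  # inner state dicts may be shared with the old optimizer (discarded below)


def expand(t):
    # Hypothetical helper: keep old rows, append a copy of row copy_idx (2-D weight or 1-D bias)
    return torch.cat([t, t[copy_idx:copy_idx + 1]])


# Replace fc_out with a wider layer, copying weights as in increment_output_size
old_fc = model[2]
new_fc = nn.Linear(40, old_out + 1)
with torch.no_grad():
    new_fc.weight.copy_(expand(old_fc.weight))
    new_fc.bias.copy_(expand(old_fc.bias))
model[2] = new_fc

# Patch every per-element buffer of fc_out.weight (4) and fc_out.bias (5); leave 'step' as is
for idx in (4, 5):
    for key, value in list(sd["state"][idx].items()):
        if key != "step":  # exp_avg, exp_avg_sq (and max_exp_avg_sq if amsgrad=True)
            sd["state"][idx][key] = expand(value)

# A fresh Adam over the new parameters, restored from the patched state dict
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.load_state_dict(sd)
```

load_state_dict() checks only that the number of parameters per group matches, not the tensor shapes, which is why the patched buffers must already have the new shapes.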

I have made a repo for basic testing and implementation.

Do raise an issue if something is buggy! Most of the code is commented, but I'm happy to explain anything.