import torch
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self, input_size=20, output_size=100):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.fc0 = nn.Linear(self.input_size, 30)
        self.fc1 = nn.Linear(30, 40)
        self.fc_out = nn.Linear(40, self.output_size)

    def increment_output_size(self, copy_idx: int):
        """Grow fc_out by one output, seeding the new row from class copy_idx."""
        old_output_size = self.output_size
        old_fc_out = self.fc_out
        self.output_size += 1
        self.fc_out = nn.Linear(40, self.output_size)
        with torch.no_grad():
            # Preserve the old rows, then initialize the new last row
            # as a copy of the row belonging to class copy_idx.
            self.fc_out.weight[:old_output_size] = old_fc_out.weight
            self.fc_out.weight[-1] = old_fc_out.weight[copy_idx].clone()
            self.fc_out.bias[:old_output_size] = old_fc_out.bias
            self.fc_out.bias[-1] = old_fc_out.bias[copy_idx].clone()
if __name__ == "__main__":
    # SETUP MODEL AND OPTIMIZER
    model = Foo()
    optimizer = torch.optim.Adam(model.parameters())
    # DO SOME TRAINING HERE (the Adam optimizer will hold state for each parameter)
    # ...
    # MODIFY MODEL
    model.increment_output_size(copy_idx=25)
    # UPDATE OPTIMIZER
    # 1. For the parameters (weight/bias) of model.fc0 and model.fc1, the state should be retained.
    # 2. For fc_out, which was modified:
    #    a. for model.fc_out.weight/bias[:100], preserve the corresponding states in Adam
    #    b. for model.fc_out.weight/bias[100], clone the state corresponding to model.fc_out.weight/bias[25]
    # DO SOME MORE TRAINING HERE
    # ...
    print("Done")
The code above attempts to train a model that, throughout the training process, may dynamically modify its self.fc_out module to accommodate an increase in the number of classes. I want the existing Adam optimizer to adaptively update its internal state in accordance with the model modification. This means updating both the relevant items in self.updater.optimizer.param_groups[0]["params"] and in self.updater.optimizer.state (and anything else, if necessary).
How can I do this? Specifically:
How can I index the optimizer fields/keys relevant to self.fc_out?
How do I replace the relevant optimizer params and states? Is it sufficient to do self.updater.optimizer.param_groups[0]["params"][index_of_fc_out_weight] = model.fc_out.weight and self.updater.optimizer.state[index_of_fc_out_weight] = new_fc_out_state? (See the sketch after this list for how I currently picture these structures.)
Are there any other fields or under-the-hood mechanisms that I need to be aware of?
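For reference, this is roughly how I currently picture those structures (a minimal sketch of my understanding, using the model and optimizer from the __main__ block above; please correct me if it is wrong):

# param_groups[0]["params"] is a plain list holding the model's Parameter
# objects; optimizer.state is a dict keyed by those same Parameter objects
# (not by index), and it is only populated after the first optimizer.step().
fc_out_weight_idx = next(
    i for i, p in enumerate(optimizer.param_groups[0]["params"])
    if p is model.fc_out.weight
)
state_for_weight = optimizer.state.get(model.fc_out.weight, {})
print(fc_out_weight_idx, list(state_for_weight.keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']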
"train a model that, throughout the training process, may dynamically modify its self.fc_out module to accommodate an increase in the number of classes."
Why would you want to do this? I think the underlying intuition may help in solving the problem.
I am working in a scenario where, over time, the number of classes will increase. Since I have no way of knowing beforehand how many classes there will be, I want the classification component to be dynamically expandable.
Furthermore, I suspect (for research purposes) that copying the weights of an existing class that is related to the new class may improve training (i.e. the new class is a few-shot problem due to the initial scarcity of training samples). Thus, I also want the Adam states to mirror the changes in the output parameters (not sure if this works, but I intend to try it out).
I hope this intuition makes sense, and helps with this problem.
This seems like quite an interesting problem. I could not find an easy-to-use end-to-end example.
Main optimizer code:
import torch


def update_optimizer_state(optimizer, old_fc_out, new_fc_out, copy_idx, old_output_size):
    """Swap the grown fc_out parameters into the optimizer and carry over the Adam state."""
    # Identify the indices corresponding to the old `fc_out` parameters.
    # Identity comparison (`is`) works because param_groups holds the very
    # same Parameter objects that the model owns.
    fc_out_weight_idx = next(i for i, p in enumerate(optimizer.param_groups[0]["params"]) if p is old_fc_out.weight)
    fc_out_bias_idx = next(i for i, p in enumerate(optimizer.param_groups[0]["params"]) if p is old_fc_out.bias)

    # Replace the old parameters with the new ones. The list is mutated in
    # place, so the optimizer picks up the new tensors on the next step().
    optimizer.param_groups[0]["params"][fc_out_weight_idx] = new_fc_out.weight
    optimizer.param_groups[0]["params"][fc_out_bias_idx] = new_fc_out.bias

    # Initialize fresh optimizer states (filled in from the old states below).
    new_state_weight = {
        'exp_avg': torch.zeros_like(new_fc_out.weight.data),
        'exp_avg_sq': torch.zeros_like(new_fc_out.weight.data),
        'step': torch.tensor(0, dtype=torch.int64)
    }
    new_state_bias = {
        'exp_avg': torch.zeros_like(new_fc_out.bias.data),
        'exp_avg_sq': torch.zeros_like(new_fc_out.bias.data),
        'step': torch.tensor(0, dtype=torch.int64)
    }

    # optimizer.state is keyed by the Parameter objects themselves and is only
    # populated after the first step(), hence the membership checks.
    if old_fc_out.weight in optimizer.state:
        old_state_weight = optimizer.state.pop(old_fc_out.weight)
        new_state_weight['exp_avg'][:old_output_size] = old_state_weight['exp_avg']
        new_state_weight['exp_avg_sq'][:old_output_size] = old_state_weight['exp_avg_sq']
        new_state_weight['exp_avg'][-1] = old_state_weight['exp_avg'][copy_idx].clone()
        new_state_weight['exp_avg_sq'][-1] = old_state_weight['exp_avg_sq'][copy_idx].clone()
        if 'step' in old_state_weight:
            new_state_weight['step'] = old_state_weight['step']
    if old_fc_out.bias in optimizer.state:
        old_state_bias = optimizer.state.pop(old_fc_out.bias)
        new_state_bias['exp_avg'][:old_output_size] = old_state_bias['exp_avg']
        new_state_bias['exp_avg_sq'][:old_output_size] = old_state_bias['exp_avg_sq']
        new_state_bias['exp_avg'][-1] = old_state_bias['exp_avg'][copy_idx].clone()
        new_state_bias['exp_avg_sq'][-1] = old_state_bias['exp_avg_sq'][copy_idx].clone()
        if 'step' in old_state_bias:
            new_state_bias['step'] = old_state_bias['step']

    # Reassign the new states to the optimizer, keyed by the new parameters.
    # (No reassignment of param_groups itself is needed; the in-place edits
    # above are already visible to the optimizer.)
    optimizer.state[new_fc_out.weight] = new_state_weight
    optimizer.state[new_fc_out.bias] = new_state_bias
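And a minimal usage sketch showing how the two pieces fit together (the key point is capturing the old layer and size before calling increment_output_size, since the method replaces them; names are just illustrative):

model = Foo()
optimizer = torch.optim.Adam(model.parameters())

# ... a few training steps here, so that Adam has populated optimizer.state ...

# Keep handles to the old layer/size before the model swaps them out.
old_fc_out, old_output_size = model.fc_out, model.output_size
model.increment_output_size(copy_idx=25)
update_optimizer_state(optimizer, old_fc_out, model.fc_out,
                       copy_idx=25, old_output_size=old_output_size)

# ... continue training with the expanded output layer ...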
I have made a repo with a basic implementation and tests.
Do raise an issue if something is buggy! Most of the code is commented, but I'm happy to explain anything.