Adding neurons to optimizer state

I’ve written the following class which allows me to add neurons to hidden layers:

import torch
import torch.nn as nn

# Classic DQN. Increase_capacity method adds new nodes to layers according to increment
# TODO: decrease capacity does not work as of yet

class DQN(nn.Module):
    def __init__(self, num_inputs, hidden, num_actions, non_linearity):
        super(DQN, self).__init__()
        
        self.num_inputs = num_inputs
        self.hidden = hidden
        self.num_actions = num_actions
        self.non_linearity = non_linearity
        
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(num_inputs, self.hidden[0]))
        
        previous = self.hidden[0]
        for hidden_layer_size in self.hidden[1:]:
            self.layers.append(nn.Linear(previous, hidden_layer_size))
            previous = hidden_layer_size
            
        self.layers.append(nn.Linear(previous, num_actions))        
        
    def forward(self, x):
        for i in range(len(self.layers) - 1):
            x = self.non_linearity(self.layers[i](x))
        return self.layers[-1](x)
    
    def increase_capacity(self, increment):
        for i in range(len(self.hidden)):
            self.hidden[i] += increment[i]
        
        bias = self.layers[0].bias.data
        weight = self.layers[0].weight.data
        self.layers[0] = nn.Linear(self.num_inputs, self.hidden[0])
        if increment[0]>0:
            self.layers[0].weight.data[0:-increment[0],:] = weight
            self.layers[0].bias.data[0:-increment[0]] = bias
        else:
            self.layers[0].weight.data[0:,:] = weight
            self.layers[0].weight.data = bias
        
        for i in range(1, len(self.layers) - 1):
            bias = self.layers[i].bias.data
            weight = self.layers[i].weight.data
            self.layers[i] = nn.Linear(self.hidden[i-1], self.hidden[i])
            if increment[i] > 0:
                if increment[i-1] >0:
                    self.layers[i].bias.data[0:-increment[i]] = bias
                    self.layers[i].weight.data[0:-increment[i],0:-increment[i-1]] = weight
                else:
                    self.layers[i].bias.data[0:-increment[i]] = bias
                    self.layers[i].weight.data[0:-increment[i],0:] = weight
            else:
                if increment[i-1] >0:
                    self.layers[i].bias.data = bias
                    self.layers[i].weight.data[0:,0:-increment[i-1]] = weight
                else:
                    self.layers[i].bias.data = bias
                    self.layers[i].weight.data[0:,0:] = weight
        
        bias = self.layers[-1].bias.data
        weight = self.layers[-1].weight.data
        self.layers[-1] = nn.Linear(self.hidden[-1], self.num_actions)
        if increment[-1] >0:
            self.layers[-1].bias.data = bias
            self.layers[-1].weight.data[:,0:-increment[-1]] = weight
        else:
            self.layers[-1].bias.data = bias
            self.layers[-1].weight.data[:,0:] = weight
    
    def act(self, state, epsilon, mask):
        if np.random.rand() > epsilon:
            state = torch.tensor([state], dtype=torch.float32, device=device)
            mask = torch.tensor([mask], dtype=torch.float32, device=device)
            q_values = self.forward(state) + mask
            action = q_values.max(1)[1].view(1, 1).item()
        else:
            action =  np.random.randint(self.num_actions)
        return action

Now I’ve written a little sanity check (whether it leads to sanity is questionable at this point): a network with 2 layer with both 1 neuron should fail to learn the x-or function, whereas a network where 4 neurons have been added should. If I initialise a new optimiser this indeed works. The optimiser I use is Adam, which keeps track of learning-rates per parameter. I’d like to keep the learning-rates of Adam for the weights and biases that already existed before I add additional neurons. The following is my failed attempt to doing so:

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# Credits to Alvations
def generate_zero():
    return random.uniform(0, 49) / 100

def generate_one():
    return random.uniform(50, 100) / 100


def generate_xor_XY(num_data_points):
    Xs, Ys = [], []
    for _ in range(num_data_points):
        # xor(0, 0) -> 0 
        Xs.append([generate_zero(), generate_zero()]); Ys.append([0])
        # xor(1, 0) -> 1
        Xs.append([generate_one(), generate_zero()]); Ys.append([1])
        # xor(0, 1) -> 1
        Xs.append([generate_zero(), generate_one()]); Ys.append([1])
        # xor(1, 1) -> 0
        Xs.append([generate_one(), generate_one()]); Ys.append([0])
    return Xs, Ys

# Initialisation
network = DQN(2,[1,1],1,F.relu)
# optimizer = optim.Adam(network.parameters(), amsgrad=False)
optimizer = optim.Adam(network.parameters(), amsgrad=True)
criterion = nn.MSELoss()

# Train 50000 steps to show 1 neuron cannot solve x-or task
for i in range(50000):
    optimizer.zero_grad()
    
    Xs, Ys = generate_xor_XY(1)
    Xs = torch.tensor(Xs)
    Ys = torch.tensor(Ys, dtype=torch.float)
    
    prediction = network(Xs)
    loss = criterion(prediction, Ys)
    
    loss.backward()
    optimizer.step()
    
print(network(torch.tensor([[1,0],[0,1],[1,1],[0,0]], dtype=torch.float)))
print(loss)

# Add 5 neurons to first layer
capacity = [4,4]
network.increase_capacity(capacity)

# Uncomment the following line and comment the lines following it for normal initialisation.
# optimizer = optim.Adam(network.parameters(), amsgrad=True)

nw_param = [p for p in network.parameters()]
new_param_group = []

layer_idx = 0
for idx, group in enumerate(optimizer.param_groups):
        for idx_p, p in enumerate(group['params']):
            # Save previous information
            prev_grad = p.grad
            old_p = copy.deepcopy(p)
            old_state = copy.copy(optimizer.state[p])
            old_step = old_state['step']
            old_exp_avg = old_state['exp_avg']
            old_exp_avg_sq = old_state['exp_avg_sq']
            old_max_exp_avg_sq = old_state['max_exp_avg_sq']

            # Remove old parameter from state
            optimizer.state.pop(p)
            
            # Weights
            if p.dim()>1:
                p = nn.Parameter(nw_param[layer_idx])
                p.grad = torch.zeros_like(p)
                new_exp_avg = torch.torch.zeros_like(p)
                new_exp_avg_sq = torch.torch.zeros_like(p)
                new_max_exp_avg_sq = torch.torch.zeros_like(p)
                p.grad[0:prev_grad.size(0),0:prev_grad.size(1)] = prev_grad                                                           
                
                optimizer.state[p]['step'] = old_step
                optimizer.state[p]['exp_avg'] = new_exp_avg
                optimizer.state[p]['exp_avg'][0:prev_grad.size(0),0:prev_grad.size(1)] = old_exp_avg
                optimizer.state[p]['exp_avg_sq'] = new_exp_avg_sq
                optimizer.state[p]['exp_avg_sq'][0:prev_grad.size(0),0:prev_grad.size(1)] = old_exp_avg_sq
                optimizer.state[p]['max_exp_avg_sq'] = new_max_exp_avg_sq
                optimizer.state[p]['max_exp_avg_sq'][0:prev_grad.size(0),0:prev_grad.size(1)] = old_max_exp_avg_sq
                new_param_group.append(p)
                
            # Biases
            else:
                p = nn.Parameter(nw_param[layer_idx])
                p.grad = torch.zeros_like(p)
                new_exp_avg = torch.zeros_like(p)
                new_exp_avg_sq = torch.zeros_like(p)
                new_max_exp_avg_sq = torch.zeros_like(p)
                p.grad[0:prev_grad.size(0)] = prev_grad
                
                optimizer.state[p]['step'] = old_step
                optimizer.state[p]['exp_avg'] = new_exp_avg
                optimizer.state[p]['exp_avg'][0:prev_grad.size(0)] = old_exp_avg
                optimizer.state[p]['exp_avg_sq'] = new_exp_avg_sq
                optimizer.state[p]['exp_avg_sq'][0:prev_grad.size(0)] = old_exp_avg_sq
                optimizer.state[p]['max_exp_avg_sq'] = new_max_exp_avg_sq
                optimizer.state[p]['max_exp_avg_sq'][0:prev_grad.size(0)] = old_max_exp_avg_sq
                new_param_group.append(p)
                
            layer_idx += 1
            
optimizer.param_groups[0]['params'] = new_param_group

print(network)

# Train 50000 steps to show by adding neurons the task can be solved
for i in range(50000):
    optimizer.zero_grad()
    
    Xs, Ys = generate_xor_XY(1)
    Xs = torch.tensor(Xs)
    Ys = torch.tensor(Ys, dtype=torch.float)
    
    prediction = network(Xs)
    loss = criterion(prediction, Ys)
    
    loss.backward()
    optimizer.step()
    
print(network(torch.tensor([[1,0],[0,1],[1,1],[0,0]], dtype=torch.float)))
print(loss)

I’m trying to get the same optimizer state, but with additional parameters for the added neurons. This seems like a convoluted way of doing it (and it doesn’t work:p). Does anyone know of an (easier) way to do this or see where I’m going wrong?