Possible to add/initialize new nodes to hidden layer partway through training?

Here is a simple working example of how I've been doing it. It is not exactly what smth described, but it works for me. In your case, you would only have one weight matrix to update (since it's the input layer).

In short:

  1. Initialize the new weights for the units being added.
  2. Concatenate them to a copy of the current weight matrices (and extend the hidden-layer bias).
  3. Recreate the layers so that their parameter sizes match the new matrices from step 2.
  4. Copy the values from step 2 into the new layers' parameters.
class Model(nn.Module):
    """Two-layer MLP (input -> hidden -> output) whose hidden layer can grow.

    Args:
        layer_size: Initial number of hidden units.
        hidden: Unused; kept so existing call sites keep working.
        input_size: Number of input features.
        output_size: Number of output features.
    """

    def __init__(self, layer_size, hidden, input_size, output_size):
        super(Model, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.layer_size = layer_size
        self.relu = nn.ReLU()

        # fcs[0]: input -> hidden, fcs[1]: hidden -> output
        self.fcs = nn.ModuleList([nn.Linear(self.input_size, self.layer_size)])
        self.fcs.append(nn.Linear(self.layer_size, self.output_size))

    def forward(self, x):
        """Apply the hidden layer, ReLU, then the output layer to ``x``."""
        return self.fcs[1](self.relu(self.fcs[0](x)))

    def add_units(self, n_new):
        """Widen the hidden layer by ``n_new`` units, keeping learned values.

        Existing weights and biases of both layers are copied unchanged.
        Weights into and out of the new units are Xavier-uniform initialized
        with ReLU gain; the new hidden-unit biases start at zero.

        ``fcs[0]`` and ``fcs[1]`` are replaced by new modules, so an optimizer
        built from the old ``parameters()`` still points at the old tensors
        and must be rebuilt (or have its param groups updated) afterwards.

        Args:
            n_new: Number of hidden units to add. ``0`` is a no-op.

        Raises:
            ValueError: If ``n_new`` is negative.
        """
        if n_new < 0:
            raise ValueError(f"n_new must be non-negative, got {n_new}")
        if n_new == 0:
            return

        old_in, old_out = self.fcs[0], self.fcs[1]
        # Keep new parameters on the same device/dtype as the existing ones.
        device, dtype = old_in.weight.device, old_in.weight.dtype
        n_inputs = old_in.weight.shape[1]
        n_outputs = old_out.weight.shape[0]
        new_size = old_in.weight.shape[0] + n_new
        gain = nn.init.calculate_gain('relu')

        # Fresh weights into (rows of fcs[0]) and out of (columns of fcs[1])
        # the new hidden units.
        hl_input = torch.empty(n_new, n_inputs, device=device, dtype=dtype)
        nn.init.xavier_uniform_(hl_input, gain=gain)
        hl_output = torch.empty(n_outputs, n_new, device=device, dtype=dtype)
        nn.init.xavier_uniform_(hl_output, gain=gain)

        # Build correctly sized layers so in/out_features and bias shapes match.
        new_in = nn.Linear(
            n_inputs, new_size, bias=old_in.bias is not None
        ).to(device=device, dtype=dtype)
        new_out = nn.Linear(
            new_size, n_outputs, bias=old_out.bias is not None
        ).to(device=device, dtype=dtype)

        # Copy old + new values in place so the modules keep proper Parameters.
        with torch.no_grad():
            new_in.weight.copy_(torch.cat([old_in.weight, hl_input], dim=0))
            new_out.weight.copy_(torch.cat([old_out.weight, hl_output], dim=1))
            if old_in.bias is not None:
                new_in.bias.copy_(
                    torch.cat([old_in.bias, old_in.bias.new_zeros(n_new)])
                )
            if old_out.bias is not None:
                new_out.bias.copy_(old_out.bias)

        self.fcs[0] = new_in
        self.fcs[1] = new_out
        self.layer_size = new_size
2 Likes