Possible to add/initialize new nodes to hidden layer partway through training?

Hi, I’m very new to using PyTorch and am still wrapping my head around it all, but given the dynamic nature of PyTorch I was wondering how it might be possible to add new neurons (potentially with different parameters than the other nodes) to the hidden layer partway through training. Alternatively, I could train first, add new neurons, then train again; that detail is not so important. Is it as simple as modifying the tensors containing the weights?
Or would it be better to allocate the layer with all potential nodes up front, with some of them ‘silenced’, and make them available over the course of training?

You can keep your “neurons” in an nn.ParameterList and keep adding new nn.Parameter tensors (what used to be Variable) to that list, whether that is individual neurons (not super efficient) or, say, a block of 4096 neurons whenever you want a new set.
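To make that concrete, here is a minimal sketch of the ParameterList idea. The names GrowingLayer and add_block are just illustrative, not from any library, and after growing you still need to hand the new parameters to your optimizer (e.g. via optimizer.add_param_group):

import torch
import torch.nn as nn

class GrowingLayer(nn.Module):
    """Hidden layer whose units live in a ParameterList and grow in blocks."""
    def __init__(self, in_features, block_size):
        super().__init__()
        self.in_features = in_features
        self.blocks = nn.ParameterList(
            [nn.Parameter(torch.randn(block_size, in_features) * 0.01)])

    def add_block(self, block_size):
        # appending registers the new Parameter with the module, but an
        # existing optimizer does not know about it yet
        self.blocks.append(
            nn.Parameter(torch.randn(block_size, self.in_features) * 0.01))

    def forward(self, x):
        # stack all blocks into one (total_units, in_features) weight matrix
        weight = torch.cat(list(self.blocks), dim=0)
        return torch.relu(x @ weight.t())

For example, layer = GrowingLayer(10, 4096) followed later by layer.add_block(4096) gives you 8192 hidden units, and optimizer.add_param_group({'params': [layer.blocks[-1]]}) makes the new block trainable.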

You can also take the “some are silenced” approach, and that’s not a bad idea either.
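A sketch of that approach too (MaskedLayer and enable_units are made-up names): allocate the full layer up front and multiply its output by a 0/1 mask buffer, flipping entries on as you un-silence units:

import torch
import torch.nn as nn

class MaskedLayer(nn.Module):
    """Fixed-size layer in which only the first active units contribute."""
    def __init__(self, in_features, max_units, active_units):
        super().__init__()
        self.fc = nn.Linear(in_features, max_units)
        mask = torch.zeros(max_units)
        mask[:active_units] = 1.0
        self.register_buffer("mask", mask)

    def enable_units(self, n_new):
        # turn on the next n_new units; their weights have existed (and been
        # initialized) since construction, they were just multiplied by zero
        active = int(self.mask.sum().item())
        self.mask[active:active + n_new] = 1.0

    def forward(self, x):
        return torch.relu(self.fc(x)) * self.mask

One nice property of this version is that the parameter tensors never change size, so the optimizer (and its state) stays valid when you enable more units.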

A bit of code showing how this should be done would be extremely helpful. I’m trying to add neurons to my initial layer and update the model as new data comes in. Is it possible to initialize the model from a ParameterList which will then update with the new layer dimensions as you append new neurons?

Here is a simple working example of how I’ve been doing it. It’s not exactly what smth described, but it works for me. In your case, you would only have one weight matrix to update (since it’s the input layer).

In short:

  1. Initialize the new weights for the units being added.
  2. Concatenate them with a copy of the current weight matrices.
  3. Resize the model parameters so they match the new matrices from step 2.
  4. Set the weight values to the values from step 2.

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, layer_size, input_size, output_size):
        super(Model, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.layer_size = layer_size
        self.relu = nn.ReLU()

        # initialize weights: input -> hidden -> output
        self.fcs = nn.ModuleList([nn.Linear(self.input_size, self.layer_size)])
        self.fcs.append(nn.Linear(self.layer_size, self.output_size))

    def forward(self, x):
        x = self.relu(self.fcs[0](x))
        return self.fcs[1](x)

    def add_units(self, n_new):
        # take a copy of the current weights and biases stored in self.fcs
        current_w = [fc.weight.data for fc in self.fcs]
        current_b = [fc.bias.data for fc in self.fcs]
        device = current_w[0].device

        # make the new weights into and out of the hidden layer you are adding neurons to
        hl_input = torch.zeros(n_new, current_w[0].shape[1], device=device)
        nn.init.xavier_uniform_(hl_input, gain=nn.init.calculate_gain('relu'))
        hl_output = torch.zeros(current_w[1].shape[0], n_new, device=device)
        nn.init.xavier_uniform_(hl_output, gain=nn.init.calculate_gain('relu'))

        # concatenate the old weights (and biases) with the new ones
        new_wi = torch.cat([current_w[0], hl_input], dim=0)
        new_wo = torch.cat([current_w[1], hl_output], dim=1)
        new_bi = torch.cat([current_b[0], torch.zeros(n_new, device=device)], dim=0)

        # recreate the two Linear layers at the new size
        self.layer_size += n_new
        self.fcs[0] = nn.Linear(self.input_size, self.layer_size).to(device)
        self.fcs[1] = nn.Linear(self.layer_size, self.output_size).to(device)

        # set the weight (and bias) data to the values built above
        self.fcs[0].weight.data = new_wi
        self.fcs[0].bias.data = new_bi
        self.fcs[1].weight.data = new_wo
        self.fcs[1].bias.data = current_b[1]
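For completeness, a usage sketch under the assumption that you recreate the optimizer (or call add_param_group) after growing the layer, since the old optimizer still points at the replaced Linear modules:

model = Model(layer_size=64, input_size=10, output_size=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# ... train for a while ...

model.add_units(16)   # hidden layer is now 80 units wide
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # rebuild so the new parameters are tracked

# ... continue training ...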

This gives me an error where my grad is None and I can’t figure out how to fix it. Any advice?

The “some are silenced” approach seems convenient if you want a new head in your network. Thanks!