MNIST Incremental Learning

I want to train iteratively: first train on 0 and 1, then pause training, modify the output layer to include a third output, and train on 0, 1 AND 2 by adding the 2s back into my dataset, so that I don't have to train from scratch. Train 10 epochs every time a new digit is added. Be sure to keep using the same model that was already trained with one fewer digit, just with the augmented output layer (the weights should not be reset; only the output layer is increased by 1). After training for 10 epochs, add the digit 3 and expand the output layer of the network again. This has to be done all the way up until the entire dataset has eventually been used for training.
How to achieve this on the below code?

Code
@ptrblck_de

For the MNIST dataset, you could either manipulate the underlying .data and .targets attributes or write a custom Dataset using the MNIST dataset as the parent class and implement your sampling logic, e.g. in the __getitem__ method, by passing a specific class flag.
I would prefer the latter, as it sounds cleaner.
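
For example, a minimal sketch of such a Dataset might look like this (the FilteredMNIST name and the allowed_classes argument are made up for illustration, and the filtering is done once in __init__ rather than per sample in __getitem__ for brevity):

import torch
from torchvision import datasets, transforms

class FilteredMNIST(datasets.MNIST):
    # Keeps only the samples whose label is in allowed_classes.
    def __init__(self, root, allowed_classes, train=True, download=True):
        super().__init__(root, train=train, download=download,
                         transform=transforms.ToTensor())
        mask = torch.isin(self.targets, torch.tensor(allowed_classes))
        self.data = self.data[mask]
        self.targets = self.targets[mask]

# Start with digits 0 and 1; later recreate the dataset with [0, 1, 2], and so on.
train_set = FilteredMNIST(root='./data', allowed_classes=[0, 1])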

To add a new neuron to the last layer, I would recommend using the functional API: store the parameters in an nn.ParameterList and append the new parameter to this list.
Alternatively, you could manipulate the last linear layer's parameters directly by reassigning a new, larger weight parameter and copying the old values into it.

I’m not sure how you would like to deal with the optimizer. Would you like to reinitialize it after each change in the model?

Thank you for giving a head start to the solution. Is there any source/link I can look up?
Also, I don’t want to reinitialize the optimizer after each change in the model.

Here is an example of directly manipulating the layer.
Note that you shouldn’t use the .data attribute, but instead wrap the manipulation in a with torch.no_grad() block.
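
A minimal sketch of that idea could look like the following (the toy architecture is just a placeholder; the point is copying the old rows into a larger nn.Linear inside a torch.no_grad() block):

import torch
import torch.nn as nn

# Toy model only for illustration; the last linear layer is the output layer.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 2))

# Grow the output layer from 2 to 3 classes while keeping the trained rows.
old_fc = model[-1]
new_fc = nn.Linear(old_fc.in_features, old_fc.out_features + 1)
with torch.no_grad():
    new_fc.weight[:old_fc.out_features].copy_(old_fc.weight)
    new_fc.bias[:old_fc.out_features].copy_(old_fc.bias)
model[-1] = new_fc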

Alternatively, here is an example of the nn.ParameterList usage:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Each entry holds one or more output rows of the final linear layer.
        self.params = nn.ParameterList()
        self.params.append(nn.Parameter(torch.randn(2, 10)))

    def forward(self, x):
        # Concatenate all stored rows into a single weight matrix and apply it.
        w = torch.cat(tuple(self.params), 0)
        x = F.linear(x, w)
        return x
    

model = MyModel()
x = torch.randn(1, 10)
out = model(x)
print(out.shape)  # torch.Size([1, 2])

# Add a new output neuron by appending another weight row
new_param = nn.Parameter(torch.randn(1, 10))
model.params.append(new_param)

out = model(x)
print(out.shape)  # torch.Size([1, 3])

In that case you could use optimizer.add_param_group with your new parameter.
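
Continuing the nn.ParameterList example above, a sketch of that could be (the SGD optimizer and the learning rate are arbitrary choices here):

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# ... train on the current subset of digits ...

# After appending the new parameter, register it with the existing optimizer
# instead of recreating it, so the already-tracked parameters keep their optimizer state.
optimizer.add_param_group({'params': [new_param]})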