Delete nodes (or layers) in MLP

I have trained a simple MLP with an input layer, one hidden layer, and an output layer of 18 nodes. I would like to selectively delete 17 of the output nodes and their associated weights/biases, leaving just one output node. I am looking for either a way to remove the 17 output nodes outright, or an alternative: delete the final layer and replace it with a 1-node output layer corresponding to the one node that would be left over after "pruning".

Below is my model definition:
class NMR_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(14000, 200)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(200, 18)

    def forward(self, input):
        return self.lin2(self.relu1(self.lin1(input)))

I have been able to zero out 17/18 of the weights and biases using the following code, but for my purposes I need them fully removed:
from torch.nn.utils import prune

# Zero the smallest-magnitude 17/18 of the entries, then write the pruned
# tensors back as fresh parameters
p = prune.L1Unstructured(amount=17/18)
pruned = p.prune(model2.lin2.bias)
pruned2 = p.prune(model2.lin2.weight)
model2.lin2.bias = torch.nn.Parameter(pruned)
model2.lin2.weight = torch.nn.Parameter(pruned2)

I think I may be on the right track with deleting the final layer and appending a new layer of only 1 node. Below is the code I used to delete the final layer, as well as to define the weights and bias I want to append to the end of the MLP:

newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))
w1 = model.lin2.weight[0]
b1 = model.lin2.bias[0]

Now if I can determine how to add the “w1” and “b1” tensors to the end of “newmodel” then I believe that will resolve this issue.
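One way to attach those values is to wrap them in a fresh 1-node nn.Linear and include it when rebuilding the Sequential. Below is a minimal runnable sketch; a randomly initialized Sequential stands in for the trained model, and the names `head` and `newmodel` are illustrative. Note that slicing with [0:1] rather than [0] keeps the tensor shapes nn.Linear expects:

```python
import torch
import torch.nn as nn

# Stand-in for the trained model (same layer shapes as NMR_Model above)
model = nn.Sequential(nn.Linear(14000, 200), nn.ReLU(), nn.Linear(200, 18))

# New 1-node head initialized from row 0 of the old output layer
head = nn.Linear(200, 1)
with torch.no_grad():
    head.weight.copy_(model[-1].weight[0:1])  # [0:1] keeps the 2-D shape (1, 200)
    head.bias.copy_(model[-1].bias[0:1])      # [0:1] keeps the 1-D shape (1,)

# Everything except the old final layer, plus the new head
newmodel = nn.Sequential(*list(model.children())[:-1], head)

x = torch.randn(2, 14000)
out = newmodel(x)
```

The new model's single output should match column 0 of the original model's output, since only the other 17 rows of the weight matrix were discarded.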

I believe I have found a solution by defining a second model with 1 output node, and then specifying the parameters from the first trained model:

class NMR_Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(14000, 200)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(200, 1)

    def forward(self, input):
        return self.lin2(self.relu1(self.lin1(input)))

model2 = NMR_Model2()
model2.lin1.weight.data = model.lin1.weight.data
model2.lin1.bias.data = model.lin1.bias.data
model2.lin2.weight.data = model.lin2.weight.data[0:1]  # [0:1] keeps shape (1, 200); [0] would give (200,)
model2.lin2.bias.data = model.lin2.bias.data[0:1]      # [0:1] keeps shape (1,); [0] would give a scalar
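As a sanity check, the copied single-node model should reproduce column 0 of the original model's output exactly. A runnable sketch (randomly initialized weights stand in for the trained model, and copy_ under no_grad is used instead of assigning .data, which is the currently recommended style):

```python
import torch
import torch.nn as nn

class NMR_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(14000, 200)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(200, 18)

    def forward(self, input):
        return self.lin2(self.relu1(self.lin1(input)))

class NMR_Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(14000, 200)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(200, 1)

    def forward(self, input):
        return self.lin2(self.relu1(self.lin1(input)))

model = NMR_Model()
model2 = NMR_Model2()

with torch.no_grad():
    model2.lin1.weight.copy_(model.lin1.weight)
    model2.lin1.bias.copy_(model.lin1.bias)
    model2.lin2.weight.copy_(model.lin2.weight[0:1])  # shape (1, 200)
    model2.lin2.bias.copy_(model.lin2.bias[0:1])      # shape (1,)

x = torch.randn(3, 14000)
out = model2(x)
```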