Hello! I have a trained feed-forward NN with a given number of inputs, and I want to remove all the weights associated with one of the inputs (including the input node itself). Basically, I want to keep everything else but have an input that is one dimension smaller than before. What is the most effective way to do this? Thank you!
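You could slice the weight parameter and reassign it to the layer as a new `nn.Parameter`. The weight of an `nn.Linear` layer has the shape `[out_features, in_features]`, so each input feature corresponds to one column; the bias belongs to the outputs and stays unchanged.
For example, if you would like to keep the first 9 input features out of 10: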
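```python
import torch
import torch.nn as nn

lin = nn.Linear(10, 5)
with torch.no_grad():
    # keep only the weight columns of the first 9 input features
    lin.weight = nn.Parameter(lin.weight[:, :9])
lin.in_features = 9  # keep the layer's metadata consistent with the new weight

x = torch.randn(2, 9)
output = lin(x)
```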
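Slicing with `[:, :9]` only drops the last feature. To drop an arbitrary input, you can build a new first layer that keeps every column except the one you want to remove. The following is a minimal sketch, assuming the first layer is an `nn.Linear` inside an `nn.Sequential`. The helper `drop_input_feature`, the example architecture, and `idx = 3` are hypothetical and only for illustration. The check at the end uses the fact that deleting column `idx` gives the same result as feeding `0` for that feature into the original model.

```python
import torch
import torch.nn as nn


def drop_input_feature(linear: nn.Linear, idx: int) -> nn.Linear:
    """Hypothetical helper: return a copy of `linear` without input feature `idx`.

    Builds a new layer with in_features - 1 inputs, copies every weight
    column except column `idx`, and copies the bias unchanged.
    """
    keep = [i for i in range(linear.in_features) if i != idx]
    new = nn.Linear(linear.in_features - 1, linear.out_features,
                    bias=linear.bias is not None)
    with torch.no_grad():
        # weight has shape [out_features, in_features]; drop one column
        new.weight.copy_(linear.weight[:, keep])
        if linear.bias is not None:
            new.bias.copy_(linear.bias)
    return new


torch.manual_seed(0)
# Hypothetical trained model: 10 inputs -> 5 hidden units -> 1 output
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
idx = 3  # input feature to remove

# Replace only the first layer; the remaining layers are reused as-is
pruned = nn.Sequential(drop_input_feature(model[0], idx), model[1], model[2])

x = torch.randn(4, 10)
x_reduced = torch.cat([x[:, :idx], x[:, idx + 1:]], dim=1)  # shape [4, 9]

# Removing column idx is equivalent to feeding 0 for that feature
x_zeroed = x.clone()
x_zeroed[:, idx] = 0.0
print(torch.allclose(pruned(x_reduced), model(x_zeroed), atol=1e-6))  # True
```

Creating a fresh layer (instead of reassigning `weight` in place) keeps `in_features` and the parameter shapes consistent automatically. In either approach the layer ends up with new parameter tensors, so if you want to keep training, create the optimizer after the change; an optimizer built earlier still references the old parameters.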