Is there a PyTorch API to remove a layer from the forward pass?

Hello all,
I have been trying to remove a layer completely from both the forward and backward passes while training the model.
Say this is the model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)  # 20 channels x 4 x 4 for a 1x28x28 input
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # normalize over the class dimension

Is there any way I can precompute the layer output, load it into a train loader object, and reuse it in later epochs, getting rid of that initial layer's training completely?
We can use the detach() method to prevent backprop, but the forward pass is still computed every time even though the weight parameters are fixed, which seems redundant. I would really appreciate any hack that makes this possible.
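A minimal sketch of the precompute-and-cache idea, assuming the first stage (conv1, max-pool, ReLU) stays frozen for the rest of training and that there is no random data augmentation (which would change the frozen stage's output every epoch). Random tensors stand in for MNIST, and `precompute`, `frozen`, and `head` are illustrative names, not PyTorch APIs. The frozen stage is run once under `torch.no_grad()`, and later epochs train only the remaining layers from a `DataLoader` over the cached features.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in for MNIST-sized data (assumption): 256 images of 1x28x28, 10 classes.
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))

# Frozen first stage, same order as Net.forward: conv1 -> max-pool -> ReLU.
frozen = nn.Sequential(nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU())
frozen.eval()  # matters if the stage contains BatchNorm/Dropout
for p in frozen.parameters():
    p.requires_grad_(False)  # make sure these weights are never trained

@torch.no_grad()  # no autograd graph is built while caching
def precompute(stage, data, batch_size=64):
    # shuffle=False (the default) keeps the cached rows aligned with `labels`.
    loader = DataLoader(TensorDataset(data), batch_size=batch_size)
    return torch.cat([stage(batch) for (batch,) in loader])

features = precompute(frozen, images)  # computed once: 256 x 10 x 12 x 12

# The remaining trainable layers of Net, starting from the cached features.
head = nn.Sequential(
    nn.Conv2d(10, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
    nn.Flatten(), nn.Linear(320, 50), nn.ReLU(),
    nn.Linear(50, 10), nn.LogSoftmax(dim=1),
)
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)
cached_loader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)

for epoch in range(2):
    for x, y in cached_loader:  # the frozen stage is never run again here
        optimizer.zero_grad()
        loss = F.nll_loss(head(x), y)
        loss.backward()
        optimizer.step()
```

The trade-off is memory: the cached activations (here 10x12x12 floats per image) may be larger than the raw inputs, so for big datasets they might need to be written to disk instead of kept in RAM.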


Hi @Sohil_Newa,
I know this is an old topic, so apologies, but were you able to find any hack for this?
Thanks & Regards

If you want to remove a particular layer, you could replace it with nn.Identity.
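A small sketch of the `nn.Identity` swap, using a self-contained copy of the `Net` from the first post. The assumption is that conv1's output has already been computed once; after the swap, `forward()` is unchanged, conv1 no longer runs, and its weights drop out of `net.parameters()`, so the input must now have conv1's output shape (10 x 24 x 24) rather than the image shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

torch.manual_seed(0)
net = Net()
images = torch.randn(8, 1, 28, 28)

with torch.no_grad():
    reference = net(images)        # full forward pass, for comparison only
    conv1_out = net.conv1(images)  # cached conv1 activations: 8 x 10 x 24 x 24

# Replace conv1 with a pass-through module; forward() itself is untouched.
net.conv1 = nn.Identity()

with torch.no_grad():
    out = net(conv1_out)  # feeds the cached activations in place of the images
```

Because `forward()` still calls `self.conv1(x)`, the replacement only works if the cached tensor is exactly what conv1 would have produced; the max-pool and ReLU after it still run as before.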
I’m not sure what the use case from the original post is, as it seems the output should be precomputed somehow.

Hi @ptrblck,

The use case is conditional computation of layers: define multiple entry points to a network and compute only the layers corresponding to the chosen entry point. I was trying to apply this to progressive training of CNNs for classification/segmentation, following the principle behind Progressive GANs. The advantage could be that this method generalizes to different standard network architectures.

So what I am looking for is below:

new_resnet18 = MagicConditionalModifier(resnet18, stages=['layer1', 'layer2'])
stage_1_input  # torch.Tensor of shape (B, 64, H, W)
output = new_resnet18(stage_1_input, stage=1)  # or stage='layer1'
stage_2_input  # torch.Tensor of shape (B, 128, H, W)
output = new_resnet18(stage_2_input, stage=2)  # or stage='layer2'

We could of course build a new network from the old one with the conditional computation paths, but it seemed to me that there had to be a better way. (Maybe I'm overthinking this!)
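One possible approximation of `MagicConditionalModifier`, assuming the network can be expressed as an ordered sequence of named children: a hypothetical `EntryPointSequential` (not a PyTorch API) that subclasses `nn.Sequential` and skips every child before the named entry point. The toy model below only mimics resnet18's layout with plain convolutions to stay self-contained. Note that an entry point receives the *input* of the named layer; in torchvision's resnet18, both layer1 and layer2 take 64-channel input (layer2 is the one that outputs 128 channels).

```python
from collections import OrderedDict

import torch
import torch.nn as nn

class EntryPointSequential(nn.Sequential):
    """Hypothetical helper: a Sequential whose forward can start at a named child."""

    def forward(self, x, start=None):
        names = [name for name, _ in self.named_children()]
        if start is not None and start not in names:
            raise KeyError(f"unknown entry point {start!r}; expected one of {names}")
        running = start is None  # with no entry point, run every child
        for name, module in self.named_children():
            if name == start:
                running = True  # everything from here on is computed
            if running:
                x = module(x)
        return x

torch.manual_seed(0)
# Toy stand-in for resnet18's structure (assumption, not the real blocks).
model = EntryPointSequential(OrderedDict([
    ("stem", nn.Sequential(nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.ReLU(),
                           nn.MaxPool2d(3, stride=2, padding=1))),
    ("layer1", nn.Conv2d(64, 64, 3, padding=1)),
    ("layer2", nn.Conv2d(64, 128, 3, stride=2, padding=1)),
    ("head", nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, 10))),
]))

x = torch.randn(2, 3, 64, 64)
full = model(x)                                  # runs every child
stage_1_input = torch.randn(2, 64, 16, 16)
out_1 = model(stage_1_input, start="layer1")     # skips "stem"
stage_2_input = torch.randn(2, 64, 16, 16)
out_2 = model(stage_2_input, start="layer2")     # skips "stem" and "layer1"
```

Skipped modules simply receive no gradients, so a single optimizer over all parameters still works. To try this on torchvision's resnet18 (untested here), one could rebuild its children, which come in the order `conv1, bn1, relu, maxpool, layer1, ..., layer4, avgpool, fc`, into the wrapper with an `nn.Flatten()` inserted before `fc`, since resnet18's own forward flattens at that point.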