Keep part of model


(Kong) #1

How can I keep only part of my model?

import torch
import torchfcn

model = torchfcn.models.FCN32s(n_class=21)
checkpoint = torch.load("./logs/MODEL-fcn32s/checkpoint.pth.tar")
model.load_state_dict(checkpoint['model_state_dict'])
model = model[0:10] # to keep until 10th layer

What about the forward pass? Do I have to redefine it?


#2

It depends on the model.
If you are slicing an nn.Sequential model, you can just keep the layers you want and wrap them in a new nn.Sequential instance:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
)

# Keep only the first two children (the first Linear and its ReLU)
model = nn.Sequential(*list(model.children())[:2])

x = torch.randn(1, 10)
output = model(x)
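
As a side note, nn.Sequential also supports slice indexing directly, which returns a new nn.Sequential holding the selected layers. Here is a minimal sketch comparing both approaches; the layer sizes are just placeholders for illustration:

```python
import torch
import torch.nn as nn

full_model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
)

# Approach 1: rebuild from the list of children
head_from_children = nn.Sequential(*list(full_model.children())[:2])

# Approach 2: slice the Sequential directly
head_from_slice = full_model[:2]

x = torch.randn(1, 10)
out_children = head_from_children(x)
out_slice = head_from_slice(x)

# Both reuse the same layer objects, so the outputs are identical
same_output = torch.equal(out_children, out_slice)
```

Both variants share the underlying layer objects (and therefore the weights) with the original model, so nothing is copied.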

For a custom model with its own forward method, slicing won't work, so you would need to rewrite forward to stop at the layer you want.
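
One way to do this is to subclass the model and override forward so it returns early. The sketch below uses a made-up TinyNet as a stand-in (it is not the FCN32s definition, and the layer names are hypothetical); for your model you would have to check which attributes its forward actually uses:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a custom model with a hand-written forward."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Conv2d(16, 21, kernel_size=1)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return self.classifier(x)


class TruncatedTinyNet(TinyNet):
    """Reuses the parent's layers but stops after the first conv block."""

    def forward(self, x):
        return self.relu1(self.conv1(x))


full = TinyNet()
truncated = TruncatedTinyNet()

# The subclass has the same parameter names, so trained weights load directly
truncated.load_state_dict(full.state_dict())

x = torch.randn(1, 3, 32, 32)
features = truncated(x)  # shape: (1, 8, 32, 32)
```

The unused layers are still stored in the truncated model. If memory matters, you could delete them after loading the weights, or load with strict=False into a model that never defines them.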