model = torchfcn.models.FCN32s(n_class=21)
checkpoint = torch.load("./logs/MODEL-fcn32s/checkpoint.pth.tar")
model.load_state_dict(checkpoint['model_state_dict'])
model = model[0:10]  # attempt to keep only the first 10 layers
What about the forward pass? Do I have to redefine it?
It depends on the model.
If you are slicing an nn.Sequential model, you can keep the desired layers and wrap them in a new nn.Sequential instance:
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
)

# keep only the first two submodules (first Linear + ReLU)
model = nn.Sequential(*list(model.children())[:2])

x = torch.randn(1, 10)
output = model(x)
If the model is a custom nn.Module with its own forward method, you would need to rewrite this method, e.g. in a wrapper module that reuses the layers you want to keep.
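A minimal sketch of that wrapper approach (the model and layer names here are illustrative, not from torchfcn): the wrapper holds references to the first layers of the original model and defines a new forward that stops early.

```python
import torch
import torch.nn as nn

class CustomModel(nn.Module):
    # toy custom model with its own forward
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class Truncated(nn.Module):
    # wrapper that reuses the first two layers and redefines forward
    def __init__(self, base):
        super().__init__()
        self.fc1 = base.fc1
        self.fc2 = base.fc2

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return torch.relu(self.fc2(x))

base = CustomModel()  # would be your loaded model
model = Truncated(base)
out = model(torch.randn(1, 10))
print(out.shape)  # torch.Size([1, 10])
```

Since the wrapper stores the original submodules (not copies), the pretrained weights loaded via load_state_dict are reused as-is.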