Save part of the model

Hi, I’m new to PyTorch…
How can I save only part of a model?
I train a model that produces 3 outputs during training, but for inference I only need one of them.
Can I load the model and save just the part I need?
That would save time at inference.

Does anyone have an example?
Thanks

I would save the whole model’s state_dict and just reimplement an “inference” model, which yields only one output. Here is a small example:


import torch
import torch.nn as nn

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 1)
        self.fc3 = nn.Linear(10, 1)
        
    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x3 = self.fc3(x)
        return x1, x2, x3
    

# Train modelA
modelA = MyModelA()
x = torch.randn(1, 10)
output1, output2, output3 = modelA(x)
# ...

# Save modelA
torch.save(modelA.state_dict(), 'modelA.pth')
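As a side note on the original question (“save just the part I need”): since a `state_dict` is just an ordered dict of tensors, you could also filter it before saving, so the checkpoint itself only contains one branch. A minimal sketch (the `fc1` key prefix and the filename are mine, matching the example model above):

```python
import torch
import torch.nn as nn

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 1)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc1(x), self.fc2(x), self.fc3(x)

modelA = MyModelA()

# Keep only the parameters of the branch needed at inference time (fc1 here)
partial_state = {k: v for k, v in modelA.state_dict().items()
                 if k.startswith('fc1')}
torch.save(partial_state, 'modelA_fc1.pth')

print(sorted(partial_state.keys()))  # ['fc1.bias', 'fc1.weight']
```

The resulting file is smaller than the full checkpoint, since `fc2` and `fc3` are never written out.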


# Duplicate modelA and add a switch for inference
class MyModelB(nn.Module):
    def __init__(self, fast_inference=False):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 1)
        self.fc3 = nn.Linear(10, 1)
        self.fast_inference = fast_inference
        
    def forward(self, x):
        if self.fast_inference:
            return self.fc1(x)
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x3 = self.fc3(x)
        return x1, x2, x3

modelB = MyModelB(fast_inference=True)
modelB.load_state_dict(torch.load('modelA.pth'))
output = modelB(x)
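Note that `MyModelB` still allocates `fc2` and `fc3` even when they are unused. If you want to avoid that entirely, one option (just a sketch, not the only way) is an inference-only module combined with `load_state_dict(..., strict=False)`, which ignores checkpoint entries that have no matching parameter. The class name `MyModelAInference` and the inline `full_state` dict are illustrative stand-ins for the trained checkpoint:

```python
import torch
import torch.nn as nn

class MyModelAInference(nn.Module):
    """Holds only fc1, so fc2/fc3 never occupy memory."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc1(x)

# Stand-in for the full training checkpoint (so this sketch is self-contained)
full_state = {
    'fc1.weight': torch.randn(1, 10), 'fc1.bias': torch.randn(1),
    'fc2.weight': torch.randn(1, 10), 'fc2.bias': torch.randn(1),
    'fc3.weight': torch.randn(1, 10), 'fc3.bias': torch.randn(1),
}

model = MyModelAInference()
# strict=False skips fc2/fc3 entries that have no matching module;
# the returned object reports what was skipped
result = model.load_state_dict(full_state, strict=False)
output = model(torch.randn(1, 10))
```

Checking `result.missing_keys` and `result.unexpected_keys` after the call is a cheap sanity check that only the intended parameters were dropped.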

Would that work for you?

For my current case, yes, that would work.
Does this approach also save memory (on the GPU)?

Also, is there a way to save the whole model, architecture + weights, as in Keras and TensorFlow?
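For what it’s worth, `torch.save` can also pickle the entire module, architecture included, similar to Keras’ `model.save`. The caveat is that the class definition must still be importable at load time, and recent PyTorch versions require `weights_only=False` to unpickle full modules. A minimal sketch (using an in-memory buffer instead of a file):

```python
import io
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 1))

# Save architecture + weights in one go (pickles the whole module)
buffer = io.BytesIO()
torch.save(model, buffer)

buffer.seek(0)
# weights_only=False is needed on newer PyTorch to load pickled modules
restored = torch.load(buffer, weights_only=False)
print(type(restored))
```

This is convenient, but the `state_dict` approach is usually recommended because pickled modules break if the class definition moves or changes.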