I am working with a multi-head architecture. Once training and inference are done, I need to pass the model to an API for another task, but the API only works with a single prediction (single-head). I can easily chop off the additional heads and keep one specific head to pass to the API. The problem is that the model's forward() method returns multiple predictions, one per head, and for the API to work it must return only one. Is there a way to change what the model returns, starting from the trained .pth file?
You could wrap forward() in another method that calls forward() under the hood and returns only what you need. Alternatively, load the model weights into another class that has the same structure but has the modifications applied in the forward() method itself, and you should be good to go!
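As a rough sketch of the first suggestion (wrapping the trained model), something along these lines might work. `MultiHeadModel`, `SingleOutputWrapper`, the layer sizes, and the `output_index` argument are all made up for illustration; the real model would be the trained one loaded from the .pth file, and this assumes its forward() returns a tuple with one tensor per head.

```python
import torch
import torch.nn as nn


class MultiHeadModel(nn.Module):
    """Hypothetical stand-in for the trained multi-head model."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 16)
        self.head_a = nn.Linear(16, 3)
        self.head_b = nn.Linear(16, 5)

    def forward(self, x):
        features = torch.relu(self.backbone(x))
        # One prediction per head, returned as a tuple.
        return self.head_a(features), self.head_b(features)


class SingleOutputWrapper(nn.Module):
    """Calls the wrapped model's forward() and keeps a single output."""

    def __init__(self, model, output_index=0):
        super().__init__()
        self.model = model
        self.output_index = output_index

    def forward(self, x):
        outputs = self.model(x)  # runs the original forward() unchanged
        return outputs[self.output_index]


model = MultiHeadModel()
# In practice the trained weights would be loaded into `model` here,
# e.g. with model.load_state_dict(torch.load(...)) on the saved .pth file.
wrapped = SingleOutputWrapper(model, output_index=1).eval()

with torch.no_grad():
    single_prediction = wrapped(torch.randn(4, 10))
```

The wrapper is itself an nn.Module, so it can be handed to code that expects a regular model. Note that the unused heads are still computed on every call; the second suggestion avoids that.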
Thank you. I understand what you mean. I can define another model with a modified forward method, name it “model2”, and load the weights from “model” into “model2”. Can you direct me to any resource on how to do that (the part about transferring the weights to another model)?
You don’t need to “transfer” weights explicitly. You can just load the weights from the saved file into the new model directly. Something like this:
import torch
import torch.nn as nn


class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 20)

    def forward(self, x):
        x = x + 10.
        return self.lin(x)


class Model2(nn.Module):
    # Same layers as Model1, so the state_dict keys match; only forward() differs.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 20)

    def forward(self, x):
        x = x * 10.
        return self.lin(x)


m1 = Model1()
torch.save(m1.state_dict(), "test_model1.pt")

# Load the saved weights into the new class.
m2 = Model2()
m2.load_state_dict(torch.load("test_model1.pt"))
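The Model1/Model2 example keeps identical layers in both classes. If the new class also drops the extra heads, the saved file will contain keys the new class does not have. One possible way to handle that, sketched here with hypothetical `TwoHeadModel` and `OneHeadModel` classes and made-up layer names, is to pass `strict=False` to `load_state_dict` and check which keys were skipped:

```python
import torch
import torch.nn as nn


class TwoHeadModel(nn.Module):
    """Hypothetical stand-in for the trained multi-head model."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 16)
        self.head_a = nn.Linear(16, 3)
        self.head_b = nn.Linear(16, 5)

    def forward(self, x):
        features = torch.relu(self.backbone(x))
        return self.head_a(features), self.head_b(features)


class OneHeadModel(nn.Module):
    """Same layer names as above, minus head_b; forward() returns one tensor."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 16)
        self.head_a = nn.Linear(16, 3)

    def forward(self, x):
        features = torch.relu(self.backbone(x))
        return self.head_a(features)


trained = TwoHeadModel()
# Stands in for the dict that torch.load would return from the saved .pth file.
state = trained.state_dict()

single = OneHeadModel()
# strict=False skips keys that the new class does not define (here, head_b.*).
result = single.load_state_dict(state, strict=False)
print("skipped keys:", result.unexpected_keys)
print("missing keys:", result.missing_keys)
```

With `strict=False` a typo in a layer name is silently skipped as well, so it is worth checking that `missing_keys` is empty and that `unexpected_keys` contains only the heads that were meant to be dropped.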