Saved (Ensemble) Model is not loading

TypeError                                 Traceback (most recent call last)

<ipython-input-58-a767e129d635> in <module>()
----> 1 m = get_model()

<ipython-input-57-a5c563c6d941> in get_model()
      1 def get_model():
      2 
----> 3     model = MyEnsemble()
      4     model.load_state_dict(torch.load("/content/final_model_without_pruning.pt"))
      5 

TypeError: __init__() missing 2 required positional arguments: 'modelA' and 'modelB'
class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB, nb_classes=10):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        # Remove last linear layer
        self.modelA.fc = nn.Linear(512, nb_classes)

        # self.modelB.classifier = nn.Linear(1024, nb_classes)
        self.modelB.classifier[6] = nn.Linear(4096, nb_classes)

        # Create new classifier
        self.classifier = nn.Linear(10, nb_classes)

    def forward(self, x):
        x1 = self.modelA(x.clone())  # clone to make sure x is not changed by inplace methods
        x1 = x1.view(x1.size(0), -1)
        x2 = self.modelB(x)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)

        x = self.classifier(F.relu(x))
        return x


modelA = models.resnet18(pretrained=True)
# modelB = models.densenet121(pretrained=True)
modelB = models.vgg16(pretrained=True)

# Freeze these models
for param in modelA.parameters():
    param.requires_grad_(False)

for param in modelB.parameters():
    param.requires_grad_(False)

# Create ensemble model
model = MyEnsemble(modelA, modelB, 5)
# x = torch.randn(1, 3, 224, 224)
# output = model(x)

def optimize(train_dataloader, valid_dataloader, model, loss_fn, optimizer, nb_epochs):
    train_losses = []
    valid_losses = []
    train_acc = []
    valid_acc = []
    best_valid_acc = 0

    for epoch in range(nb_epochs):
        print(f'\nEpoch {epoch+1}/{nb_epochs}')
        print('-------------------------------')
        train_loss, train_accuracy = train(train_dataloader, model, loss_fn ,optimizer)
        train_losses.append(train_loss)
        train_acc.append(train_accuracy)
        valid_loss, validation_accuracy = validate(valid_dataloader, model, loss_fn)
        valid_losses.append(valid_loss)
        valid_acc.append(validation_accuracy)

        with open("/content/drive/MyDrive/Colab Notebooks/train_losses.txt", "wb") as fp:   #Pickling
            pickle.dump(train_losses, fp)
        with open("/content/drive/MyDrive/Colab Notebooks/train_acc.txt", "wb") as tp:   #Pickling
            pickle.dump(train_acc, tp)
        with open("/content/drive/MyDrive/Colab Notebooks/val_losses.txt", "wb") as vp:   #Pickling
            pickle.dump(valid_losses, vp)
        with open("/content/drive/MyDrive/Colab Notebooks/val_acc.txt", "wb") as va:   #Pickling
            pickle.dump(valid_acc, va)

        if validation_accuracy > best_valid_acc:
            torch.save(model.state_dict(), '/content/drive/MyDrive/Colab Notebooks/final_model_without_pruning.pt')
            torch.save(model.state_dict(), '/content/final_model_without_pruning.pt')
            best_valid_acc = validation_accuracy
    
    print('\nTraining has completed!')
    return train_losses, valid_losses, train_acc, valid_acc

def get_model():

    model = MyEnsemble()
    model.load_state_dict(torch.load("/content/final_model_without_pruning.pt"))

    model.cuda()
    model.eval()
    return model

m = get_model()

Your constructor expects arguments modelA, modelB and nb_classes. But since, you have defined a default value for nb_classes = 10 in the __init__ method of MyEnsemble, it will work even if you haven’t passed, the value for nb_classes. But modelA and modelB needs to be defined for object creation.

You code should look like:-

m = get_model()

def get_model():
    model = MyEnsemble(modelA = some_model, modelB = some_model, nb_classes = can_skip_passing_but_a_good_practice_to_do_so)`

    ....do_whatever_you_want

1 Like

Oh so sorry but i am not sure about that… Thanks !!! >> MyBalanceNow