TypeError Traceback (most recent call last)
<ipython-input-58-a767e129d635> in <module>()
----> 1 m = get_model()
<ipython-input-57-a5c563c6d941> in get_model()
1 def get_model():
2
----> 3 model = MyEnsemble()
4 model.load_state_dict(torch.load("/content/final_model_without_pruning.pt"))
5
TypeError: __init__() missing 2 required positional arguments: 'modelA' and 'modelB'
class MyEnsemble(nn.Module):
    """Late-fusion ensemble of two backbone classifiers.

    Each backbone's final linear layer is replaced by a fresh
    ``nb_classes``-way head. In ``forward`` the two heads' outputs are
    concatenated, passed through ReLU, and fused by a final linear layer.

    Args:
        modelA: backbone whose final linear layer is the attribute ``.fc``
            (as in the resnet18 used by this notebook). Modified in place.
        modelB: backbone whose final linear layer is ``.classifier[6]``
            (as in the vgg16 used by this notebook). Modified in place.
        nb_classes: number of output classes of every head and of the
            fused output.
    """

    def __init__(self, modelA, modelB, nb_classes=10):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB
        # Read each head's input width from the layer being replaced so other
        # ResNet/VGG variants work too; fall back to the resnet18 (512) /
        # vgg16 (4096) widths if the existing layer doesn't expose it.
        width_a = getattr(getattr(self.modelA, "fc", None), "in_features", 512)
        width_b = getattr(self.modelB.classifier[6], "in_features", 4096)
        # Replace (not remove) each backbone's last linear layer.
        self.modelA.fc = nn.Linear(width_a, nb_classes)
        self.modelB.classifier[6] = nn.Linear(width_b, nb_classes)
        # The fused vector is both heads concatenated: 2 * nb_classes wide.
        # (For nb_classes=5 this equals the previous hard-coded 10, so saved
        # checkpoints keep the same shapes.)
        self.classifier = nn.Linear(2 * nb_classes, nb_classes)

    def forward(self, x):
        """Return fused logits of shape (batch, nb_classes) for batch ``x``."""
        # Clone so any in-place op inside modelA cannot alter the tensor
        # that modelB receives.
        x1 = self.modelA(x.clone())
        x1 = x1.view(x1.size(0), -1)
        x2 = self.modelB(x)
        x2 = x2.view(x2.size(0), -1)
        fused = torch.cat((x1, x2), dim=1)
        return self.classifier(F.relu(fused))
# Two ImageNet-pretrained backbones fused by MyEnsemble.
# (densenet121 was tried as an alternative second backbone; it would need its
# `.classifier` head replaced instead of `.classifier[6]`.)
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm against the installed version.
modelA = models.resnet18(pretrained=True)
modelB = models.vgg16(pretrained=True)

# Freeze every pretrained weight of both backbones. The new head layers that
# MyEnsemble creates afterwards are not affected by this loop.
for backbone in (modelA, modelB):
    for param in backbone.parameters():
        param.requires_grad_(False)

# 5-way ensemble; the checkpoint that get_model() loads is expected to have
# these 5-way heads.
# Smoke test: model(torch.randn(1, 3, 224, 224))
model = MyEnsemble(modelA, modelB, 5)
def optimize(train_dataloader, valid_dataloader, model, loss_fn, optimizer, nb_epochs):
    """Train ``model`` for ``nb_epochs`` epochs, checkpointing the best epoch.

    Each epoch does the following:

    - calls ``train`` then ``validate`` (defined elsewhere; each returns a
      ``(loss, accuracy)`` pair) and appends the results to the running
      histories;
    - re-pickles all four histories to Google Drive;
    - saves ``model.state_dict()`` whenever validation accuracy strictly
      improves on the best seen so far.

    Args:
        train_dataloader: passed to ``train`` each epoch.
        valid_dataloader: passed to ``validate`` each epoch.
        model: the network being optimized; its ``state_dict`` is saved.
        loss_fn: passed to both ``train`` and ``validate``.
        optimizer: passed to ``train``.
        nb_epochs: number of epochs to run.

    Returns:
        ``(train_losses, valid_losses, train_acc, valid_acc)``: lists with one
        entry per epoch.
    """
    drive_dir = "/content/drive/MyDrive/Colab Notebooks"
    train_losses = []
    valid_losses = []
    train_acc = []
    valid_acc = []
    # Start at -inf (not 0) so the first epoch always writes a checkpoint.
    # Otherwise a run whose accuracy never exceeds 0 saves nothing, and
    # loading /content/final_model_without_pruning.pt later fails.
    best_valid_acc = float("-inf")
    # (file name, history list) pairs re-dumped every epoch.
    # NOTE: despite the .txt extension these files hold pickles; read them
    # back with pickle.load on a file opened in "rb" mode.
    history_files = (
        ("train_losses.txt", train_losses),
        ("train_acc.txt", train_acc),
        ("val_losses.txt", valid_losses),
        ("val_acc.txt", valid_acc),
    )
    for epoch in range(nb_epochs):
        print(f'\nEpoch {epoch+1}/{nb_epochs}')
        print('-------------------------------')
        train_loss, train_accuracy = train(train_dataloader, model, loss_fn, optimizer)
        train_losses.append(train_loss)
        train_acc.append(train_accuracy)
        valid_loss, validation_accuracy = validate(valid_dataloader, model, loss_fn)
        valid_losses.append(valid_loss)
        valid_acc.append(validation_accuracy)
        # Rewrite the full histories each epoch so progress survives a
        # disconnected or crashed session.
        for filename, history in history_files:
            with open(f"{drive_dir}/{filename}", "wb") as fp:
                pickle.dump(history, fp)
        if validation_accuracy > best_valid_acc:
            # One copy on Drive (persistent) and one on local disk.
            torch.save(model.state_dict(), f"{drive_dir}/final_model_without_pruning.pt")
            torch.save(model.state_dict(), '/content/final_model_without_pruning.pt')
            best_valid_acc = validation_accuracy
    print('\nTraining has completed!')
    return train_losses, valid_losses, train_acc, valid_acc
def get_model(
    path="/content/final_model_without_pruning.pt",
    nb_classes=5,
    device="cuda",
):
    """Rebuild the MyEnsemble architecture and load trained weights into it.

    ``MyEnsemble`` has no default backbones, so ``MyEnsemble()`` raises
    ``TypeError: __init__() missing 2 required positional arguments``.
    Instead, the architecture used for training (resnet18 + vgg16 with
    ``nb_classes``-way heads) is reconstructed, and then the saved
    ``state_dict`` is mapped onto it.

    Args:
        path: checkpoint file containing the model's ``state_dict``.
        nb_classes: class count the checkpoint was trained with (5 in this
            notebook). It must match, or ``load_state_dict`` fails on
            mismatched shapes.
        device: device the model is moved to. It is also used as
            ``map_location``, so the checkpoint's tensors are loaded onto
            the same device.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # Pretrained ImageNet weights are unnecessary here: load_state_dict
    # (strict by default) overwrites every parameter and buffer, so build
    # untrained backbones and skip the download.
    model = MyEnsemble(models.resnet18(), models.vgg16(), nb_classes)
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    return model


m = get_model()