I want to train from scratch the AlexNet model:
model = models.alexnet(pretrained=False)
num_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1] # Remove last layer
features.extend([nn.Linear(num_features, 4)]) # Add our layer with 4 outputs
model.classifier = nn.Sequential(*features) # Replace the model classifier
for param in model.features.parameters():
param.require_grad = True
I want to save the parameters/hyperparmeters (weights, bias, model values/structure etc.).
This is the training function that I am using:
def train(model, dataloaders, criterion1, criterion2, optimizer, num_epochs):
since = time.time()
train_loss = []
val_acc_history = []
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
print('-' * 10)
for phase in ['train', 'validation']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0.0
running_corrects = 0
for images, target in dataloaders[phase]:
images = images.to(device)
labels = target['label'].to(device)
norm = target['norm_box'].to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
out_bbox = model(images)
loss_bbox = criterion2(out_bbox, norm)
loss = loss_bbox
if phase == 'train':
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
running_corrects += IoU(norm, out_bbox, 0.3, batch_size)
epoch_loss = running_loss / len(dataloaders[phase].dataset)
if phase == 'train':
train_loss.append(epoch_loss)
epoch_acc = running_corrects / len(dataloaders[phase].dataset)
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
if phase == 'validation' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
if phase == 'validation':
val_acc_history.append(epoch_acc)
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))
torch.save(best_model_wts, 'weights/best_wts.pth')
return model, train_loss, val_acc_history
When I load the state_dict
:
wts = model.load_state_dict(torch.load('weights/best_wts.pth))
I get the message:
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
And when trying to look through the state_dict:
print(list(wts.keys()))
I get the error:
AttributeError: 'IncompatibleKeys' object has no attribute 'keys'
Any advice would be greatly appreciated.