I am trying to load a trained model with PyTorch Lightning, but loading the saved checkpoint fails with the following error:
loading state_dict for Learner: Unexpected key(s) in state_dict:
My code is shown below. Can someone suggest how I can fix this problem?
import os
import sys

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import nn
class Network(nn.Module):
    """Fully connected network: 5 inputs -> 64 -> 64 -> 1, with tanh activations."""

    def __init__(self):
        super().__init__()
        # Keep the exact layer order so the state_dict keys
        # ("layers.0.weight", "layers.2.weight", ...) stay unchanged.
        stack = [
            nn.Linear(5, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Cast ``x`` to float32 and pass it through the layer stack."""
        return self.layers(x.float())
class Learner(pl.LightningModule):
    """Lightning module that trains a wrapped network with an MSE objective.

    Args:
        model: The network to train. It is stored as ``self.model``, so its
            parameters appear in this module's state_dict under the
            ``model.`` prefix.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one ``(x, y)`` batch.

        Returns:
            A dict with the loss under ``'loss'`` and the metrics to log
            under ``'log'``.
        """
        x, y = batch
        # Follow the model's device instead of forcing CUDA, so the step
        # also works when the module lives on the CPU.
        device = next(self.model.parameters()).device
        x = x.to(device)
        y = y.to(device)
        # Sum the model output over dim 1 to get one prediction per sample.
        y_hat = torch.sum(self.model(x), dim=1)
        loss = nn.MSELoss()(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=0.005)

    def train_dataloader(self):
        """Return the module-level ``trainloader`` defined by the script."""
        return trainloader
# Load the input tensor and the target tensor from disk.
data = torch.load("/some/data/to/load/data.pt")
truth = torch.load("/some/data/to/load/truth.pt")
dataset = torch.utils.data.TensorDataset(data, truth)

# Train on at most the first 100000 samples; clamping to the dataset size
# avoids out-of-range indices when fewer samples are available.
subset_indices = np.arange(0, min(100000, len(dataset)), 1)
subset = torch.utils.data.Subset(dataset, subset_indices)
trainloader = torch.utils.data.DataLoader(
    subset,
    batch_size=15,
    shuffle=True,
    num_workers=5,
)

CHECKPOINT_PATH = "./lightning_logs/version_5364248/checkpoints/epoch=0.ckpt"

checkpoint_callback = ModelCheckpoint(
    filepath=CHECKPOINT_PATH,
    verbose=True,
    monitor='val_loss',
    mode='min',
)

model = Network()

# load_from_checkpoint is a classmethod: it constructs a NEW Learner and
# returns it, so it must be called on the class and its result must be kept.
# Extra keyword arguments are forwarded to Learner.__init__, which is how the
# required ``model`` argument is supplied.
# NOTE(review): if the checkpoint was written by a module with a different
# layout, strict loading will still report unexpected keys -- compare
# torch.load(CHECKPOINT_PATH)["state_dict"].keys() with
# Learner(model).state_dict().keys() to confirm.
learn = Learner.load_from_checkpoint(CHECKPOINT_PATH, model=model).cuda()

trainer = pl.Trainer(
    min_epochs=1,
    max_epochs=1,
    checkpoint_callback=checkpoint_callback,
    gpus=1,
)
trainer.fit(learn)