Cannot Load Model using Pytorch Lightning

I am attempting to load a trained model using Pytorch Lightning, but I get the following error when I try to load the produced model:

loading state_dict for Learner:  Unexpected key(s) in state_dict:

My code is shown below. Can someone suggest how I fix this problem?

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import sys
import os
    
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers =nn.Sequential(
            nn.Linear(5,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,1)
        )
       
   
    def forward(self, x):
        x = self.layers(x.float())
        return x
        
class Learner(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.cuda()
        y = y.cuda()
        y_hat = torch.sum(self.model(x), axis=1)
        loss = nn.MSELoss()(y_hat, y)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.005)

    def train_dataloader(self):
        return trainloader
         
data = torch.load("/some/data/to/load/data.pt")
truth = torch.load("/some/data/to/load/truth.pt")

dataset = torch.utils.data.TensorDataset(data, truth)
subset_indices = np.arange(0,100000, 1)
subset = torch.utils.data.Subset(dataset, subset_indices)
trainloader = torch.utils.data.DataLoader(
             subset,
             batch_size=15,
             shuffle=True,
             num_workers=5
 )
   
checkpoint_callback = ModelCheckpoint(  filepath="./lightning_logs/version_5364248/checkpoints/epoch=0.ckpt",
     verbose=True,
     monitor='val_loss',
     mode='min'
 )

model = Network()
learn = Learner(model).cuda()

##CODE FAILS HERE###
learn.load_from_checkpoint("./lightning_logs/version_5364248/checkpoints/epoch=0.ckpt")
##CODE FAILS HERE###

trainer = pl.Trainer(min_epochs=1, max_epochs=1, checkpoint_callback=checkpoint_callback, gpus=1)
trainer.fit(learn)

Are you seeing any unexpected keys in the state_dict or is the warning empty?

There are unexpected keys, all of which are from ‘model’.

However, when I run 'print (learn)', I see that all of the layers are correctly defined in the Learner before I load the function.

Unfortunately, I’m not familiar enough with Lightning, so adding @williamFalcon to the topic.