Retrieve the PyTorch model from a PyTorch Lightning model

I have trained a PyTorch Lightning model that looks like this:

In [16]: MLP
Out[16]:
DecoderMLP(
  (loss): RMSE()
  (logging_metrics): ModuleList(
    (0): SMAPE()
    (1): MAE()
    (2): RMSE()
    (3): MAPE()
    (4): MASE()
  )
  (input_embeddings): MultiEmbedding(
    (embeddings): ModuleDict(
      (LCLid): Embedding(5, 4)
      (sun): Embedding(5, 4)
      (day_of_week): Embedding(7, 5)
      (month): Embedding(12, 6)
      (year): Embedding(3, 3)
      (holidays): Embedding(2, 1)
      (BusinessDay): Embedding(2, 1)
      (day): Embedding(31, 11)
      (hour): Embedding(24, 9)
    )
  )
  (mlp): FullyConnectedModule(
    (sequential): Sequential(
      (0): Linear(in_features=60, out_features=435, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.13371112461182535, inplace=False)
      (3): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=435, out_features=435, bias=True)
      (5): ReLU()
      (6): Dropout(p=0.13371112461182535, inplace=False)
      (7): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (8): Linear(in_features=435, out_features=435, bias=True)
      (9): ReLU()
      (10): Dropout(p=0.13371112461182535, inplace=False)
      (11): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (12): Linear(in_features=435, out_features=435, bias=True)
      (13): ReLU()
      (14): Dropout(p=0.13371112461182535, inplace=False)
      (15): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (16): Linear(in_features=435, out_features=435, bias=True)
      (17): ReLU()
      (18): Dropout(p=0.13371112461182535, inplace=False)
      (19): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (20): Linear(in_features=435, out_features=435, bias=True)
      (21): ReLU()
      (22): Dropout(p=0.13371112461182535, inplace=False)
      (23): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (24): Linear(in_features=435, out_features=435, bias=True)
      (25): ReLU()
      (26): Dropout(p=0.13371112461182535, inplace=False)
      (27): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (28): Linear(in_features=435, out_features=435, bias=True)
      (29): ReLU()
      (30): Dropout(p=0.13371112461182535, inplace=False)
      (31): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (32): Linear(in_features=435, out_features=435, bias=True)
      (33): ReLU()
      (34): Dropout(p=0.13371112461182535, inplace=False)
      (35): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (36): Linear(in_features=435, out_features=1, bias=True)
    )
  )
)

I need the corresponding PyTorch model to use in one of my other applications.

Is there a simple way to do that?

I thought of saving a checkpoint and loading it in the other application, but I don’t know how to do that.
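
Roughly, what I have in mind is the sketch below. It assumes that a LightningModule is an ordinary torch.nn.Module, so a sub-module's state_dict can be saved and reloaded with plain PyTorch. Wrapper and build_plain_mlp are hypothetical stand-ins with a much smaller network, not my actual DecoderMLP, and the weights go to an in-memory buffer instead of a real file.

```python
import io

import torch
from torch import nn


def build_plain_mlp() -> nn.Sequential:
    # Hypothetical, much smaller stand-in for the (mlp) block printed above.
    return nn.Sequential(nn.Linear(60, 435), nn.ReLU(), nn.Linear(435, 1))


class Wrapper(nn.Module):
    # Hypothetical stand-in for the trained Lightning model; only the
    # nn.Module behaviour matters for this sketch.
    def __init__(self) -> None:
        super().__init__()
        self.mlp = build_plain_mlp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


trained = Wrapper().eval()

# Save only the weights of the sub-module, not the whole wrapper object.
buffer = io.BytesIO()  # a file path such as "mlp.pt" would work the same way
torch.save(trained.mlp.state_dict(), buffer)
buffer.seek(0)

# In the other application: rebuild the same architecture in plain PyTorch
# and load the saved weights into it.
plain = build_plain_mlp()
plain.load_state_dict(torch.load(buffer))
plain.eval()

# Check that both modules give the same output for the same input.
x = torch.randn(2, 60)
with torch.no_grad():
    same = torch.allclose(trained.mlp(x), plain(x))
```

This only covers one sub-module. In my real model the embeddings would presumably need the same treatment, and the architecture would have to be rebuilt by hand on the other side.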

Can you please help?
Thanks