Saving TensorDictModule

Hey!
What is the recommended approach for saving a trained model to disk in TorchRL?

Unlike in standard PyTorch, saving the nn.Module directly might not be the best approach here (for example, when chaining modules with TensorDictSequential).

Much appreciated!

Hey,

I am new here, but wanted to try and help out. I threw your question into Claude 3.5 Sonnet (my teacher). Sonnet recommended a few methods, but .state_dict() seems like the winner.

  1. Saving the State Dict: The most flexible and recommended way is to save the state dictionary of your model. This works well even with complex TorchRL structures.

To save:
torch.save(model.state_dict(), "model_state_dict.pth")

To load:
model.load_state_dict(torch.load("model_state_dict.pth"))

Note: .state_dict() works for TensorDictSequential as well, since it is an nn.Module subclass. For example:
torch.save(tensor_dict_sequential.state_dict(), "tensordict_sequential_state.pth")

  2. Using TensorDict: If your model uses TensorDict structures, you might want to save the entire TensorDict:

from tensordict import TensorDict

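To save (wrapping the model's state dict in a TensorDict):
td_checkpoint = TensorDict({"model": model.state_dict()}, batch_size=[])
torch.save(td_checkpoint, "model_tensordict.pth")

To load (a pickled TensorDict is not a plain tensor file, so weights_only=False is needed on PyTorch 2.6+):
loaded_state = torch.load("model_tensordict.pth", weights_only=False)
model.load_state_dict(loaded_state["model"].to_dict())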

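  3. Saving the Entire Model: This is generally not recommended, especially for complex structures, because the pickled file depends on the exact class definitions and module paths being importable at load time. Still, you can save the entire model: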

torch.save(model, "full_model.pth")

Load with:
model = torch.load("full_model.pth", weights_only=False)  # required on PyTorch 2.6+, where weights_only defaults to True

I look forward to seeing what others recommend, and hope this helps.

Regards,
Kyle

Thanks for the answer mate,
You're right. I didn't think it was so well supported.

For the sake of completeness, here is some code that verifies it:

import copy

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn


def save_checkpoint_using_state_dict(module: TensorDictModule, dst_path: str) -> None:
    # Saves only parameters and buffers, not the module structure.
    torch.save(module.state_dict(), dst_path)


def load_checkpoint_using_state_dict(module: TensorDictModule, src_path: str) -> TensorDictModule:
    # The module must already be constructed with the same architecture.
    module.load_state_dict(torch.load(src_path, weights_only=True))
    return module


def save_checkpoint(module: TensorDictModule, dst_path: str) -> None:
    # Pickles the whole module object (structure + weights).
    torch.save(module, dst_path)


def load_checkpoint(src_path: str) -> TensorDictModule:
    # Unpickling a full module needs weights_only=False (the default became True in PyTorch 2.6).
    return torch.load(src_path, weights_only=False)


if __name__ == '__main__':
    path = '/tmp/model.pth'
    m1 = nn.Linear(1, 1)
    m2 = nn.Linear(1, 1)
    module1 = TensorDictModule(m1, in_keys='inp', out_keys='out')
    module2 = TensorDictModule(m2, in_keys='out', out_keys='out2')
    module = TensorDictSequential(module1, module2)
    # Reference copy of the original weights to compare against.
    module_copy = copy.deepcopy(module)

    # Using the state dict:
    save_checkpoint_using_state_dict(module, path)
    # Overwrite the weights so that the reload is actually tested.
    with torch.no_grad():
        for p in module.parameters():
            p.zero_()

    module = load_checkpoint_using_state_dict(module, path)
    for p1, p2 in zip(module.parameters(), module_copy.parameters()):
        assert torch.equal(p1, p2)

    # Pickling the full module (no state dict):
    save_checkpoint(module, path)
    with torch.no_grad():
        for p in module.parameters():
            p.zero_()

    module = load_checkpoint(path)
    for p1, p2 in zip(module.parameters(), module_copy.parameters()):
        assert torch.equal(p1, p2)
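If you also want to resume training rather than just run inference, a common extension of the state-dict approach (the pattern used in PyTorch's "saving and loading a general checkpoint" tutorial) is to store the optimizer state next to the model weights in one plain dict. A minimal sketch, using a plain nn.Sequential for brevity; a TensorDictSequential should behave the same way since it is an nn.Module. The "step" entry is just a hypothetical example of extra metadata:

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one optimizer step so the optimizer has some state worth saving.
loss = model(torch.randn(8, 1)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
# A plain dict of tensors and Python primitives; "step" is hypothetical metadata.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": 1,
    },
    path,
)

# Resume: rebuild the model and optimizer, then restore their states.
restored_model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=1e-3)
checkpoint = torch.load(path, weights_only=True)
restored_model.load_state_dict(checkpoint["model"])
restored_optimizer.load_state_dict(checkpoint["optimizer"])
step = checkpoint["step"]
```

Keeping everything as tensors and primitives means the checkpoint still loads with weights_only=True.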

Cheers!
