I am new here, but I wanted to try to help out. I ran your question through Claude 3.5 Sonnet (my teacher). Sonnet recommended a few methods, but .state_dict() seems like the winner.
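Saving the state dict: The most flexible and recommended approach is to save your model's state dictionary, which holds its parameters and buffers. This also works with complex TorchRL structures, because TorchRL modules are ordinary torch.nn.Module subclasses. Note that loading requires you to first rebuild the model with the same architecture.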
To save:
torch.save(model.state_dict(), "model_state_dict.pth")
To load (after re-creating the model with the same architecture):
model.load_state_dict(torch.load("model_state_dict.pth"))
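Here is a minimal, self-contained round trip of the two lines above. The make_model function and its layer sizes are hypothetical stand-ins for your own network; the point is that the model is rebuilt first and the saved tensors are then loaded into it. Passing weights_only=True is safe here because a state dict contains only tensors.

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for your own policy/value network
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))


torch.manual_seed(0)
model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state_dict.pth")
    # Save only parameters and buffers, not the Python class itself
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then copy the saved tensors into it
    restored = make_model()
    state = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)

x = torch.randn(3, 4)
same = torch.equal(model(x), restored(x))
print(same)  # True
```

Because only tensors are stored, the file stays valid even if you later refactor the surrounding code, as long as the layer names and shapes still match.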
Note: .state_dict() works for TensorDictSequential as well, since it is also an nn.Module.
Example: torch.save(tensor_dict_sequential.state_dict(), "tensordict_sequential_state.pth")
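Using TensorDict: If your model uses TensorDict structures, you might prefer to store the state as a TensorDict:
from tensordict import TensorDict
To save, wrap the model's state dict in a TensorDict (batch_size=[] means the TensorDict has no batch dimension):
state_dict = TensorDict({"model": model.state_dict()}, batch_size=[])
torch.save(state_dict, "model_tensordict.pth")
To load, pass weights_only=False on PyTorch 2.6 and later (the file holds a pickled TensorDict rather than plain tensors), then convert the nested entry back to a regular dict:
loaded_state = torch.load("model_tensordict.pth", weights_only=False)
model.load_state_dict(loaded_state["model"].to_dict())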
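Saving the entire model: This is generally not recommended, especially for complex structures, because the file pickles the model object and depends on its class being importable, unchanged, at load time. Still, you can save the entire model:
torch.save(model, "full_model.pth")
Load with (PyTorch 2.6 and later require weights_only=False for full objects; only do this with files you trust):
model = torch.load("full_model.pth", weights_only=False)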
I look forward to seeing what others recommend, and hope this helps.