How to define and load a reduced-precision LibTorch model?

Hi,

Let us say I have a simple model as follows:

// assuming `namespace nn = torch::nn;`
auto sequential = nn::Sequential(
  nn::Linear(4, 4),
  nn::ReLU(),
  nn::GroupNorm(1, 4),
  nn::Linear(4, 4)
);

// Sequential is a module holder, so use `->` rather than `.`
sequential->to(torch::kFloat16);
torch::load(sequential, "./sequential-fp16.pt");

Instead of constructing the model (which defaults to torch::kFloat32) and then converting it to torch::kFloat16, how do I go about defining it as a torch::kFloat16 model in the first place?

I have a large model and would like to avoid the large fp32 memory allocation before whittling it down to fp16 and loading the weights.

Thanks!