How to define and load a reduced-precision LibTorch model?


Let us say I have a simple model as follows:

#include <torch/torch.h>
using namespace torch;

auto sequential = nn::Sequential(
  nn::Linear(4, 4),
  nn::GroupNorm(1, 4),
  nn::Linear(4, 4));
torch::load(sequential, "./");

Instead of defining the model (which defaults to torch::kFloat32) and then converting it to torch::kFloat16, how do I go about defining it as a torch::kFloat16 model in the first place?

I have a large model and would like to avoid the large fp32 memory allocation up front, only to whittle it down to fp16 before loading.
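For reference, here is a minimal sketch of the define-then-convert approach I am trying to avoid (the `model.pt` path is a placeholder):

```cpp
#include <torch/torch.h>

int main() {
  // The model's parameters are allocated in float32 by default...
  auto sequential = torch::nn::Sequential(
      torch::nn::Linear(4, 4),
      torch::nn::GroupNorm(1, 4),
      torch::nn::Linear(4, 4));

  // ...and only afterwards converted to float16, so for a large
  // model the full fp32 allocation happens first anyway.
  sequential->to(torch::kHalf);

  // Load previously saved (fp16) parameters; placeholder path.
  torch::load(sequential, "model.pt");
}
```

This works, but the intermediate fp32 allocation is exactly the cost I am hoping to skip.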