Hi,
Let us say I have a simple model as follows:
auto sequential = nn::Sequential(
    nn::Linear(4, 4),
    nn::ReLU(),
    nn::GroupNorm(1, 4),
    nn::Linear(4, 4));
sequential->to(torch::kFloat16);  // Sequential is a ModuleHolder, so -> rather than .
torch::load(sequential, "./sequential-fp16.pt");
Instead of defining the model with the default dtype (torch::kFloat32) and then converting it to torch::kFloat16, how do I go about defining it as a torch::kFloat16 model in the first place?
I have a large model, and I would like to avoid the large fp32 memory allocation that happens at construction, before the model is whittled down to fp16 and the weights are loaded.
Thanks!