Loaded model in Libtorch gives different results each time

Hi, I want to use my pretrained model (which was trained in PyTorch) in my C++ project. I saved the network in Python and loaded it in C++ using JIT tracing.

//Python
traced_script_module = torch.jit.trace(model, input)
traced_script_module.save("model.pt")
//C++
torch::jit::script::Module model = torch::jit::load(file_name);

I am passing a CUDAFloatTensor as input. My model is supposed to return three outputs.

auto outputs = model.forward({ input });
auto heatmap = outputs.toTuple()->elements()[0].toTensor();
auto bbox = outputs.toTuple()->elements()[1].toTensor();
auto scan_rec = outputs.toTuple()->elements()[2].toTensor();
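
By CUDAFloatTensor I mean a float tensor that already lives on the GPU, roughly like this (the shape here is just illustrative, not my real input):

// sketch only: the shape is made up, the real input has whatever shape the model was traced with
auto input = torch::rand({ 1, 64, 32, 32, 32 },
	torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA));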

I realized the forward pass doesn't always give correct results. Sometimes it gives the answer below, which is the right one (I checked in Python, where the outputs are always consistent):
first three elements of bbox : 0.4311, 0.4620, 0.3915

But most of the time I get values like the following in all three outputs, and I get an exception when I try to use them:
first three elements of bbox : 9.3593e-36, 3.9978e+07, -5.3179e+37

This is really weird. I am assuming there is some kind of overflow going on. I don't want to train any network in C++; I just want to use my pretrained model as a deterministic function. I also set the following before calling the forward pass, just in case, but it didn't help. I have no idea what is going on.

torch::manual_seed(0);
torch::NoGradGuard no_grad;
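
Put together, the inference part looks roughly like this (a sketch with illustrative names, not my exact code):

torch::manual_seed(0);
torch::NoGradGuard no_grad;   // no autograd, inference only
model.eval();                 // put dropout/batchnorm layers into eval mode
auto outputs = model.forward({ input });
auto bbox = outputs.toTuple()->elements()[1].toTensor();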

Environment:
Windows 10
CUDA 10.1
Visual Studio 2019
Libtorch 1.5 Nightly Release

I also realized my input changes after the forward call. Sometimes it stays the same, but most of the time it is corrupted like the outputs. As before, I get an exception when I try to use the input after the forward call.
first three elements of input before forward: 0.3747 0.3584 0.0090
first three elements of input after forward: 3.2892e+08 -nan 4.2039e-45
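
A minimal sketch of the kind of check I mean (keep an owning copy of the input and compare it after the call):

auto input_before = input.clone();   // deep copy that owns its own memory
auto outputs = model.forward({ input });
std::cout << "input unchanged: " << torch::allclose(input, input_before) << std::endl; // prints 0 when the buffer is clobbered or NaN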

It seems the tensor is using uninitialized memory. Did you delete or reinitialize the tensor somehow (or are you leaving the scope and thus the tensor might be freed)?
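
For example, torch::from_blob does not take ownership of the buffer, so something like this contrived sketch would end up reading freed memory unless you clone() the result or copy it to another device:

torch::Tensor make_tensor() {
	std::vector<float> data(16, 1.0f);
	// BAD: the returned tensor would still point at data(), which is freed when the function returns
	// return torch::from_blob(data.data(), { 16 }, torch::kFloat);
	// OK: clone() copies the buffer into memory owned by the tensor itself
	return torch::from_blob(data.data(), { 16 }, torch::kFloat).clone();
}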

I also tried keeping everything in the same scope, but unfortunately it didn't help. I didn't delete or reinitialize the tensor. However, my input to the model actually comes from another network. My code is below.

//this is because of Windows.h; delete after solving the library-loading issue below
#define NOMINMAX

#include <ATen/ATen.h>
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
#include <torch/all.h>
#include <memory>
#include <Windows.h>
#include <algorithm>
#include <random>
#include <fstream>
#include <tuple>
#include <string>
#include <vector>
#include <cassert>

struct Vox {
	torch::Tensor dims; // 3 dims
	float res;
	torch::Tensor grid2world; // 4x4 dims
	torch::Tensor sdf; // 1 x 1 x dims[2] x dims[1] x dims[0]
	
};
Vox load_vox(std::string filename, bool is_cad = false) {

	std::ifstream f(filename, std::ios::binary);
	assert(f.is_open());
	Vox vox;
	std::vector<int32_t> dims;
	float res;
	std::vector<float> grid2world;
	std::vector<float> sdf;
	dims.resize(3);
	grid2world.resize(16);
	f.read((char*)dims.data(), 3 * sizeof(int32_t));
	f.read((char*)&res, sizeof(float));
	f.read((char*)grid2world.data(), 16 * sizeof(float));
	int n_elems = dims[0] * dims[1] * dims[2];

	sdf.resize(n_elems);
	f.read((char*)sdf.data(), n_elems * sizeof(float));

	vox.dims = torch::from_blob(dims.data(), { 3 }, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU)).clone(); // clone() so the tensor owns its data; from_blob does not copy

	vox.grid2world = torch::from_blob(grid2world.data(), { 4,4 }, torch::TensorOptions().dtype(torch::kFloat)).to(at::Device(torch::kCUDA));
	vox.res = res;
	vox.sdf = torch::from_blob(sdf.data(), { 1, 1, dims[2],dims[1],dims[0] }, torch::TensorOptions().dtype(torch::kFloat)).to(at::Device(torch::kCUDA));
	f.close();
	
	return vox;


}
int main() {

	torch::jit::script::Module backbone, model;
	//Workaround for a nightly-build issue, no fix for now: https://github.com/pytorch/pytorch/issues/31611
	LoadLibraryA("torch_cuda.dll");
	
	
	torch::manual_seed(0);
	if (torch::cuda::is_available()) {
		std::cout << "CUDA is available! Training on GPU." << std::endl;
	}
	else {
		std::cout << "CUDA is not available." << std::endl;
		return -1;
	}
	const std::string scene_id = "scene0479_01";
	const std::string scannet_name = "C:/scannet";
	
	const std::string vox_filename = scannet_name + "/scannet_vox/"+scene_id+"/" + scene_id + "_res30mm_rot0.vox";
	Vox v = load_vox(vox_filename);
	
	const std::string& backbone_name = "backbone.pt";
	const std::string& model_name= "model.pt";
	try {
		backbone = torch::jit::load(backbone_name);
		model = torch::jit::load(model_name);
	}
	catch (const c10::Error& e) {
		std::cerr << "error loading the model\n";
		return -1;
	}
	torch::NoGradGuard no_grad;
	backbone.eval();
	v.sdf = torch::clamp(v.sdf, -0.15, 2.0); //checked, looks fine
	at::Tensor input = backbone.forward({ v.sdf }).toTensor(); //this one works okay

	model.eval();
	auto outputs = model.forward({ input });

	auto heatmap = outputs.toTuple()->elements()[0].toTensor(); //checked
	auto bbox = outputs.toTuple()->elements()[1].toTensor();
	auto scan_rec = outputs.toTuple()->elements()[2].toTensor();
	std::cout << input[0][0][0][0] << std::endl;

	return 0;
}

I tested this in the stable version (1.4) and the code appears to work there without a problem. However, I need advanced tensor indexing in my code, like the examples below, and I believe that API is only in the nightly version. Is there an equivalent way to do this kind of indexing? I might switch to the stable version if so.

....
using namespace torch::indexing;
target.index_put_(
	{ 0, "...",
	  Slice(tmin[0].item<int>(), tmax[0].item<int>()),
	  Slice(tmin[1].item<int>(), tmax[1].item<int>()),
	  Slice(tmin[2].item<int>(), tmax[2].item<int>()) },
	src.index({ 0, "...",
	  Slice(smin[0].item<int>(), smax[0].item<int>()),
	  Slice(smin[1].item<int>(), smax[1].item<int>()),
	  Slice(smin[2].item<int>(), smax[2].item<int>()) }));
...
...
noc = noc.view({ n_batch_size,3,-1 });
mask = mask.view({ n_batch_size,1,-1 }).expand_as(noc);
auto a = noc.index({ i,(mask[i] > 0.5) }).view({ 3,-1 }).to(at::Device(torch::kCUDA));
...
...
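
Would something along these lines be an equivalent in 1.4, using slice() and masked_select() instead of the new indexing API? (untested sketch; the dimension numbers assume a 1 x C x D x H x W layout)

// untested sketch: slice dims 2,3,4 on both sides, then copy in place
auto dst_view = target.slice(2, tmin[0].item<int>(), tmax[0].item<int>())
                      .slice(3, tmin[1].item<int>(), tmax[1].item<int>())
                      .slice(4, tmin[2].item<int>(), tmax[2].item<int>());
auto src_view = src.slice(2, smin[0].item<int>(), smax[0].item<int>())
                   .slice(3, smin[1].item<int>(), smax[1].item<int>())
                   .slice(4, smin[2].item<int>(), smax[2].item<int>());
dst_view.copy_(src_view);

// boolean-mask indexing via masked_select (returns a flat copy, not a view)
auto a = noc[i].masked_select(mask[i] > 0.5).view({ 3, -1 }).to(at::Device(torch::kCUDA));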

Getting different results each time may be caused by overlooking model.eval(), I think.