OpenGL, LibTorch and CUDA interop: doing inference on texture data

So I’m running an encoder/decoder network for image colorization on OpenGL texture data.
Wanting to eke out as much performance as possible, I want to keep all the data on the GPU, so I’m using CUDA/OpenGL interop to map the texture as a graphics resource, get a CUDA array from it using cudaGraphicsSubResourceGetMappedArray, and copy the data to a device pointer that I pass to torch::from_blob() to create a torch tensor.
Then, if I just copy this tensor to another CUDA array which I’ve mapped another texture to, I get an exact copy of the original image back in the newly rendered texture.

The problem is that I’m not sure how to get the RGBA texture data into a (c,h,w) format that the network can use, and then convert it back to a format that OpenGL can use again. Currently the model expects a (1,1,299,299) tensor (the L in CIELAB color space) and outputs a (2,299,299) tensor (the AB channels). The tensors must then be concatenated before being copied to the texture-mapped CUDA array. This is the current output when running the texture data through the network: (299x299 texture in left corner)

replacing shading with flat white color:

image

here are some snippets of code:

// init: register both OpenGL textures with CUDA so they can be mapped as graphics resources later
	// source texture (copied into a staging buffer by the rendering code)
	CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(
		&m_imageTextureResource,
		m_imageTexture->id(),
		GL_TEXTURE_2D,
		cudaGraphicsRegisterFlagsNone));
	// destination texture (receives the network output)
	CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(
		&m_colorTextureResource,
		m_colorTexture->id(),
		GL_TEXTURE_2D,
		cudaGraphicsRegisterFlagsNone));
//.....
//.....

//rendering
//.....
// map resource for use with cuda
	CHECK_CUDA_ERROR(cudaGraphicsMapResources(1, &m_imageTextureResource));
	CHECK_CUDA_ERROR(cudaGraphicsMapResources(1, &m_colorTextureResource));
	// get cuda array from mapped resource
	cudaArray_t imageDeviceArray;
	CHECK_CUDA_ERROR(cudaGraphicsSubResourceGetMappedArray(&imageDeviceArray, m_imageTextureResource, 0, 0));

	cudaArray_t colorDeviceArray;
	CHECK_CUDA_ERROR(cudaGraphicsSubResourceGetMappedArray(&colorDeviceArray, m_colorTextureResource, 0, 0));
	cudaChannelFormatDesc desc;
	cudaExtent extent;
	uint flags;
	CHECK_CUDA_ERROR(cudaArrayGetInfo(&desc, &extent, &flags, imageDeviceArray));

	// to be able to pass the data to torch, we need to copy the cuda array to a device pointer
	// here, we allocate the necessary memory
	void* devicePtr = nullptr;
	CHECK_CUDA_ERROR(cudaMalloc(&devicePtr, extent.width* extent.height * 4 * sizeof(float)));

	// now we copy the array to the device pointer
	CHECK_CUDA_ERROR(cudaMemcpy2DFromArray(devicePtr, extent.width * 4 * sizeof(float), imageDeviceArray, 0, 0, extent.width * 4 * sizeof(float), extent.height, cudaMemcpyDeviceToDevice));
	{
int64_t(extent.width) };
		//auto tensor = torch::from_blob(devicePtr,{1,3,long long(extent.height) ,long long(extent.width) }, torch::TensorOptions().dtype(torch::kFloat32).layout(torch::kStrided).device(torch::kCUDA));
		
		auto texture_tensor = torch::from_blob(devicePtr, { 1, long long(extent.height), long long(extent.width), 3 }, torch::TensorOptions().dtype(torch::kF32).layout(torch::kStrided).device(torch::kCUDA));
		texture_tensor = texture_tensor.permute({ 0,3,1,2 });
		//auto test = torch::ones(torch::IntArrayRef{ int64_t(1), int64_t(3), int64_t(m_imageSize.y), int64_t(m_imageSize.x) }, torch::TensorOptions().dtype(torch::kFloat32).layout(torch::kStrided).device(torch::kCUDA));

		try {

			std::vector<torch::jit::IValue> inputs;
			inputs.push_back(texture_tensor);
			auto embedding = m_ModelInceptionV3.forwardPass(inputs);

			inputs.clear();
			auto encoderL = texture_tensor.index({ 0,0 }).unsqueeze(0).unsqueeze(0);
			std::cout << "encoder L: " << encoderL.sizes() << std::endl;
			inputs.push_back(encoderL);
			inputs.push_back(embedding);

			auto EncoderAB = m_ModelEncoder.forwardPass(inputs);

			auto test = torch::cat({ encoderL.squeeze(0) ,EncoderAB.squeeze()}, 0);

			cudaMemcpy2DToArray(colorDeviceArray, 0, 0, test.data_ptr<float>(), m_imageSize.x * 4 * sizeof(float), m_imageSize.x * 4 * sizeof(float), m_imageSize.y, cudaMemcpyDeviceToDevice);
		}
		catch (const c10::Error& e) {
			std::cerr << "inference error\n" << e.what() << std::endl;
		}
	
	}

	CHECK_CUDA_ERROR(cudaFree(devicePtr));
	devicePtr = nullptr;

	// unmap resource (to be able to use it in OpenGL again)
	CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(1, &m_imageTextureResource));
	CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(1, &m_colorTextureResource));