How can we use autograd with a function implemented in C?

Hi, I want to use a function implemented in the C language.

First, the overall flow is this:
input --> NN --> param --> [function (C language) --> output] --> loss(output - reference)
The part in [] is done in C.

It would be too time-consuming to reimplement the C function in Python.
Also, the calculation is done in C because it is optimized with Intel MKL.

So I tried to use ctypes to send arrays from Python to C.
But in this case, I have to detach the tensors from autograd.

After the calculation in C, I send the arrays back to Python and convert them into tensors.
But when I try to train the NN, the weights are not updated.
This is of course because the autograd graph no longer exists.
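
Here is a tiny illustration of what I mean (the multiplication by 2 just stands in for the C call):

import torch

params = torch.randn(3, requires_grad=True)
arr = params.detach().numpy()        # hand this array to C via ctypes
out = torch.from_numpy(arr * 2.0)    # result coming back from C
print(out.grad_fn)                   # None -> the autograd graph is cut here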

Here comes the problem.
First, I don't know how to make autograd work for the C part of the function.
Second, I don't know how to update the remaining part, the NN part, without using loss.backward.

How can I do it?
Here is the code:

for epoch in range(num_epoch):
	scheduler.step()
	running_loss = 0

	# train
	for i, (inputs, raws_in, targets, raws_out) in enumerate(train_loader):
		inputs = inputs.to(device)
		targets = targets.to(device)			

		# Polynomial model - python version
		# params = model(inputs)
		# inputs = torch.stack([inputs[:, 0] * inputs[:, 0], inputs[:, 0], torch.FloatTensor([1, 1, 1, 1, 1, 1]).to(device), inputs[:, 1]])
		# outputs = params.matmul(inputs)
		# outputs = torch.diag(outputs).view([-1, 1])

		# sc3d code input - c
		params = model(inputs)
		grad = params.grad

		print(inputs)
		print(params)
		print('----------------------------------------------')

		params_c = params.to(torch.device('cpu'))
		params_c = params_c.detach().numpy()

		inputs_c = torch.stack([inputs[:, 0] * inputs[:, 0], inputs[:, 0], torch.FloatTensor([1, 1, 1, 1, 1, 1]).to(device), inputs[:, 1]])
		inputs_c = inputs_c.to(torch.device('cpu'))
		inputs_c = inputs_c.detach().numpy()

		libname = os.path.abspath(os.path.join(os.path.dirname(__file__), "sc3d.so"))
		LIBC = ctypes.CDLL(libname)

		_double_pp = np.ctypeslib.ndpointer(dtype=np.uintp, ndim=1)
		# _pp = np.ctypeslib.ndpointer(dtype=ctypes.c_float, ndim=1)

		LIBC.polyno.argtypes = [ctypes.c_int, ctypes.c_int, _double_pp, _double_pp, _double_pp] 
		LIBC.polyno.restype = None
		
		params_pp = (params_c.__array_interface__['data'][0] + np.arange(params_c.shape[0])*params_c.strides[0]).astype(np.uintp) 
		inputs_pp = (inputs_c.__array_interface__['data'][0] + np.arange(inputs_c.shape[0])*inputs_c.strides[0]).astype(np.uintp) 
		
		m = ctypes.c_int(params_pp.shape[0])
		n = ctypes.c_int(inputs_pp.shape[0]) 

		# outputs = np.zeros_like(m, m) # param(6*4) * input(4*6)
		outputs = np.zeros((6, 1)) # param(6*4) * input(4*6)
		outputs_pp = (outputs.__array_interface__['data'][0] + np.arange(outputs.shape[0])*outputs.strides[0]).astype(np.uintp)

		LIBC.polyno(m, n, params_pp, inputs_pp, outputs_pp)
		outputs = torch.FloatTensor(outputs).to(device)
		# outputs.requires_grad_(True)

		print(inputs_c)
		print(params_c)
		print(outputs)
		print(targets)

		# Forward pass
		loss = criterion(outputs, targets)
		# loss = torch.sqrt(criterion(outputs, targets))

		# Backward and optimize
		optimizer.zero_grad()
		loss.backward(gradient=params.grad)
		
		optimizer.step()
		running_loss += float(losses(outputs, targets))
		if (i + 1) % train_step == 0:
			print('Epoch [{}/{}], Step [{}/{}], Train Loss: {}'
						.format(epoch + 1, num_epoch, i + 1, train_step, running_loss/train_step), end='')
	train_losses.append(running_loss / train_step)

Hi,

In general, I am afraid that you cannot do that. Only operations within torch will be tracked by autograd, which is what enables you to do the backward pass and update the parameters.
But there is one exception: you can extend autograd.Function and use C/C++ or Python code inside it to enable autograd tracking for computations that PyTorch does not support internally. For instance, you cannot use numpy and still track gradients, but if you can obtain the derivatives manually, then you can simply create a box that does that and lets you do the backward.
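
To make this concrete, here is a minimal sketch of such a "box". The NumpyCube name and the cube operation are only illustrations; the forward could just as well call a C routine through ctypes, as long as you can supply the derivative yourself in backward:

import numpy as np
import torch

class NumpyCube(torch.autograd.Function):
    # forward runs outside autograd (NumPy here, but it could be a C call);
    # backward supplies the hand-derived gradient
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = np.power(x.detach().cpu().numpy(), 3)   # the "black box" computation
        return torch.as_tensor(y, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2             # d(x^3)/dx, written by hand

x = torch.randn(5, requires_grad=True)
NumpyCube.apply(x).sum().backward()
print(x.grad)                                       # equals 3 * x**2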

Here are some links that will help you throughout the process, and I think the documentation is clear enough for a start:

  1. Learn about autograd
  2. Extending PyTorch (autograd, nn.Module, etc.)
  3. Tutorial on a custom layer/network with manual autograd
  4. Extending PyTorch with C++/CUDA and creating bindings directly from C++ for PyTorch

Bests

Hi,
Thank you for your kindness.

But there is still one problem.
I don't know how to write the backward part.
The backward part consists of the gradient of the function, and that gradient involves solving an eigenvalue problem.
So I can't exactly describe what the gradient is.

In the end, how can I provide the gradient of this black box to autograd?
Here is the code:

class sc3d(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, param):

        ctx.save_for_backward(input, param)

        param_c = param.to(torch.device('cpu'))
        param_c = param_c.detach().numpy()

        input_c = torch.stack([input[:, 0] * input[:, 0], input[:, 0], torch.FloatTensor([1, 1, 1, 1, 1, 1]).to(device), input[:, 1]])
        input_c = input_c.to(torch.device('cpu'))
        input_c = input_c.detach().numpy()

        libname = os.path.abspath(os.path.join(os.path.dirname(__file__), "sc3d.so"))
        LIBC = ctypes.CDLL(libname)

        _double_pp = np.ctypeslib.ndpointer(dtype=np.uintp, ndim=1)
        # _pp = np.ctypeslib.ndpointer(dtype=ctypes.c_float, ndim=1)

        LIBC.polyno.argtypes = [ctypes.c_int, ctypes.c_int, _double_pp, _double_pp, _double_pp]
        LIBC.polyno.restype = None

        param_pp = (param_c.__array_interface__['data'][0] + np.arange(param_c.shape[0])*param_c.strides[0]).astype(np.uintp)
        input_pp = (input_c.__array_interface__['data'][0] + np.arange(input_c.shape[0])*input_c.strides[0]).astype(np.uintp)

        m = ctypes.c_int(param_pp.shape[0])
        n = ctypes.c_int(input_pp.shape[0])

        # outputs = np.zeros_like(m, m) # param(6*4) * input(4*6)
        output = np.zeros((6, 1)) # param(6*4) * input(4*6)
        output_pp = (output.__array_interface__['data'][0] + np.arange(output.shape[0])*output.strides[0]).astype(np.uintp)

        LIBC.polyno(m, n, param_pp, input_pp, output_pp)
        output = torch.FloatTensor(output).to(device)
        # outputs.requires_grad_(True)

        return output

    @staticmethod
    def backward(ctx, grad_output):

        input, param = ctx.saved_tensors
        grad_input = grad_param = None

        # calculate gradient part
        # grad_param = grad_output.clone()

        return grad_input, grad_param
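
For context, once the backward is filled in, I plan to use it in the training loop roughly like this (model, criterion, and optimizer are the same as in my earlier code):

params = model(inputs)
outputs = sc3d.apply(inputs, params)   # C forward + hand-written backward
loss = criterion(outputs, targets)

optimizer.zero_grad()
loss.backward()                        # should now reach the network weights
optimizer.step()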

Best regards,

Choi

Hi Choi!

If your main roadblock really is just backpropagating through the
“eigenvalue problem,” then you could perform that part of the
calculation using pytorch’s torch.eig() (or torch.symeig()),
and use autograd to backpropagate through the eigenvalue piece.

If, furthermore, the rest of your loss function can also be written with
pytorch tensor functions, you could use autograd to backpropagate
through the entire loss function.
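
For example, here is a minimal sketch of backpropagating through the
symmetric eigen-solver (torch.linalg.eigh in recent releases replaces
torch.symeig(); use whichever your version provides):

import torch

A = torch.randn(4, 4, dtype=torch.double)
A = A + A.t()                         # make it symmetric
A.requires_grad_(True)

evals, evecs = torch.linalg.eigh(A)   # differentiable eigen-solve
evals.sum().backward()                # autograd computes the eigenvalue gradient
print(A.grad)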

But even if not, you could structure your loss function like this:

loss = (tensor functions) --> c-function-A --> torch.eig() --> c-function-B --> (more tensor functions)

If you can figure out the gradient for c-function-A and c-function-B
(even if you can’t for the eigenvalue problem), you can wrap your
custom C functions in autograd.Function subclasses with both .forward()
and .backward() methods,
and use autograd to backpropagate through the whole thing. This way,
you’re letting pytorch’s autograd do the “hard part” (namely, calculating
the gradient of the eigenvalue problem) for you, while you write the
gradients (the .backward() functions) for c-function-A and
c-function-B, which you can’t, or don’t want to, write using pytorch
tensor functions.
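
Here is a minimal sketch of that structure, where CFunctionA and
CFunctionB stand in for your real C routines (the arithmetic inside
them is purely illustrative) and the eigen-solve in the middle is
left to autograd:

import torch

class CFunctionA(torch.autograd.Function):
    # placeholder for your first C routine and its hand-written gradient
    @staticmethod
    def forward(ctx, x):
        return x * 2.0                 # pretend this is the ctypes call

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2.0       # derivative of the routine, by hand

class CFunctionB(torch.autograd.Function):
    # placeholder for your second C routine and its hand-written gradient
    @staticmethod
    def forward(ctx, e):
        ctx.save_for_backward(e)
        return e ** 2                  # pretend this is the ctypes call

    @staticmethod
    def backward(ctx, grad_output):
        e, = ctx.saved_tensors
        return grad_output * 2.0 * e   # derivative of the routine, by hand

x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
A = CFunctionA.apply(x)
A = 0.5 * (A + A.t())                  # symmetrize before the eigen-solve
evals, _ = torch.linalg.eigh(A)        # pytorch computes this gradient for you
loss = CFunctionB.apply(evals).sum()
loss.backward()                        # gradient flows from the loss back to x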

Best.

K. Frank

Hi
Thanks for your answer.

But there is one more question.
As you said, if I use c-function-A, do I have to calculate the gradient of c-function-A from its analytic form?

Is there any way to use a numerical approximation of the gradient of c-function-A instead?
For example, can I use a finite difference like [ c-function-A(x + delta_x) - c-function-A(x) ] / delta_x as the backward part?
If I can, is it much slower than the analytic way?

Best regards,
Choi

Hi Choi!

I’ve never tried this approach – using numerical differentiation to
calculate .backward() – but I have thought about it.

Advantages: Numerical differentiation is a very general-purpose tool,
so you can apply it to lots of use cases.

Disadvantages: A sledgehammer is also a general-purpose tool that
can be “applied” to almost anything. But it is a crude tool, so you
may not get the results you want …

More concretely, numerical differentiation can be numerically tricky.
The core issue is how to choose your step size, delta_x, so
that you get a good estimate of the local derivative but don’t
pick up too much round-off error when you take the finite differences.
(Consider numerically differentiating abs(x) when your step size
crosses x = 0.)

(This is a well-studied problem: Many scientific subroutine libraries
include numerical-differentiation algorithms, but if you look at them
you will see that they are much more involved than the simple
approach you outlined above.)

It is likely to be slower than a numerical implementation of the analytic
derivative (by possibly a lot).
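
For what it’s worth, a quick-and-dirty central-difference backward
might look like the sketch below. The black_box function and the eps
value are purely illustrative, and this only works as written if the
black box acts elementwise; a general vector-valued function would
need one extra evaluation per input dimension to build the full
Jacobian:

import torch

def black_box(x):
    # stands in for a C routine you can evaluate but not differentiate
    return x ** 3

class NumericalDiff(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return black_box(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        eps = 1e-4                     # choosing this step size is the hard part
        deriv = (black_box(x + eps) - black_box(x - eps)) / (2.0 * eps)
        return grad_output * deriv     # central difference, applied elementwise

x = torch.randn(5, dtype=torch.double, requires_grad=True)
NumericalDiff.apply(x).sum().backward()
print(x.grad - 3 * x.detach() ** 2)    # error of order eps**2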

If you can calculate the derivative analytically – even if doing so is a
fair amount of work – it will be worth doing it that way. If it were me, I
would only calculate .backward() numerically as a quick-and-dirty
hack to test something or if I were really struggling with the analytic
derivative.

Best.

K. Frank

Hi,
Thanks for your answer.

Based on your advice, I decided to use Python code only.
But there is one more problem here.

There is currently no way to use PyTorch's eigenvalue solver on a complex eigenvalue problem.

Therefore, I'm considering using an external library.

Do you know any other way to use autograd with a complex eigenvalue problem?
If not, is it possible in TensorFlow?

Best regards,
Choi