Calculate derivative of function approximated by NN

Hi all! I want to approximate a function and its derivative using neural networks and PyTorch. The difficulty is that in the problem I eventually want to solve, the derivative is not given, so the network cannot learn the derivative directly. Instead, I want to obtain the derivative by differentiating the function the network approximates. Another difficulty for me is that I want to normalize the input and output without destroying the autograd graph. For normalization I am using sklearn's MinMaxScaler, but I cannot use it directly to revert the normalization, because I would have to detach the PyTorch tensor. So I read its min_ and scale_ attributes and revert the normalization myself. The overall implementation looks like this:

def calculate_jacobi(model, inp_data, inp_scaler, out_scaler):
    """Differentiate the de-normalized model output w.r.t. the *normalized* input.

    Caution: the gradient is taken w.r.t. the normalized input
    ``x_norm = x * inp_scaler.scale_ + inp_scaler.min_``. By the chain rule,
    every entry therefore equals the derivative w.r.t. the raw input ``x``
    divided by ``inp_scaler.scale_``.

    Args:
        model: torch module mapping normalized inputs to normalized outputs.
        inp_data: raw input accepted by ``inp_scaler.transform``.
        inp_scaler: fitted scaler used (via ``normalize_inp``) to normalize
            ``inp_data``.
        out_scaler: fitted scaler exposing ``min_``/``scale_`` for the output.

    Returns:
        numpy array of shape ``(n_rows,) + normalized_input.shape``. Entry ``i``
        is the gradient of ``unnorm_output[i]`` w.r.t. the whole normalized
        input batch. Each output row is differentiated as a scalar, so it must
        hold exactly one value.
    """
    normalized_input = normalize_inp(inp_data, inp_scaler)
    # Gradients are requested w.r.t. this (leaf) tensor below.
    normalized_input.requires_grad = True

    output = model(normalized_input)
    unnorm_output = revert_output_normalization_keeping_gradient(output, out_scaler)

    jacobi = []
    for out in unnorm_output:
        # torch.autograd.grad returns a fresh gradient for each row instead of
        # accumulating into normalized_input.grad. (numpy views of that .grad
        # share its memory and would all change as it accumulates.) It never
        # writes parameter .grad either, so the model does not have to be
        # frozen, which would also disable any later training.
        (grad,) = torch.autograd.grad(out, normalized_input, retain_graph=True)
        jacobi.append(grad.numpy())

    return np.array(jacobi)

def normalize_inp(data, scaler):
    """Normalize raw data with a fitted scaler and return it as a float32 CPU tensor.

    ``scaler.transform`` operates outside of torch, so the returned tensor is a
    new leaf with no autograd connection to ``data``.
    """
    scaled = scaler.transform(data)
    return torch.tensor(scaled, dtype=torch.float32, device="cpu")

def revert_output_normalization_keeping_gradient(normalized_output, scaler):
    """Invert a fitted MinMaxScaler's transform using differentiable torch ops.

    Computes ``(normalized_output - scaler.min_) / scaler.scale_`` per feature
    (the last dimension). The result stays connected to the autograd graph of
    ``normalized_output``, so gradients pick up the ``1 / scale_`` factor.

    Args:
        normalized_output: tensor whose last dimension holds the features,
            e.g. a (batch, out_dim) model output. An empty batch is allowed.
        scaler: fitted scaler exposing per-feature ``min_`` and ``scale_``
            arrays (such as sklearn's MinMaxScaler).

    Returns:
        Tensor of the same shape and device as ``normalized_output``.
        Floating-point inputs keep their dtype; other inputs are promoted to
        torch's default dtype.
    """
    dtype = (
        normalized_output.dtype
        if normalized_output.is_floating_point()
        else torch.get_default_dtype()
    )
    device = normalized_output.device
    # Cast the numpy scaler parameters to the tensor's dtype/device, so that
    # broadcasting over the feature axis neither upcasts nor moves the data.
    minimum = torch.as_tensor(scaler.min_, dtype=dtype, device=device)
    scale = torch.as_tensor(scaler.scale_, dtype=dtype, device=device)
    return (normalized_output.to(dtype) - minimum) / scale

Are my idea and this implementation correct, or am I missing something? As a next step, I tried to validate the idea with a simple example function, f(x) = x ** 3 + 7 - 3 * x, whose derivative f'(x) = 3 * x ** 2 - 3 is known, so I can also test the derivative. I generated 10 example points and compared the network output and the derivative calculated above with the true values. My results were the following:

**
f(1.872678462764552) = 7.94930682729936
model(1.872678462764552) = 9.3166685

f_deriv(1.872678462764552) = 7.520773874706617
model_deriv(1.872678462764552) = 49.068344
**
f(2.4448616490329247) = 14.279205121677888
model(2.4448616490329247) = 14.303566

f_deriv(2.4448616490329247) = 14.932045448735973
model_deriv(2.4448616490329247) = 134.91791
**
f(4.4445634270474565) = 81.46485592743929
model(4.4445634270474565) = 82.55166

f_deriv(4.4445634270474565) = 56.262432171143494
model_deriv(4.4445634270474565) = 566.58594
**
f(6.139904037163629) = 220.04497877649737
model(6.139904037163629) = 219.52823

f_deriv(6.139904037163629) = 110.0952647567347
model_deriv(6.139904037163629) = 1062.7377
**
f(2.571250050248884) = 16.28562426971684
model(2.571250050248884) = 16.192503

f_deriv(2.571250050248884) = 16.83398046271467
model_deriv(2.571250050248884) = 161.14128
**
f(2.817382806029692) = 20.91123859249368
model(2.817382806029692) = 20.832308

f_deriv(2.817382806029692) = 20.812937627135227
model_deriv(2.817382806029692) = 211.47461
**
f(3.332011539496338) = 33.996960093061624
model(3.332011539496338) = 34.29489

f_deriv(3.332011539496338) = 30.30690269801027
model_deriv(3.332011539496338) = 305.8891
**
f(7.201266008061427) = 358.8411261715284
model(7.201266008061427) = 354.93872

f_deriv(7.201266008061427) = 152.57469635658288
model_deriv(7.201266008061427) = 1469.3788
**
f(2.04220977299413) = 9.390653150236977
model(2.04220977299413) = 10.303757

f_deriv(2.04220977299413) = 9.51186227073821
model_deriv(2.04220977299413) = 67.201416
**
f(5.550606341194321) = 161.35809257192074
model(5.550606341194321) = 162.24344

f_deriv(5.550606341194321) = 89.42769226471982
model_deriv(5.550606341194321) = 867.60986

So the function approximation is quite good, but the computed derivative is always roughly 10 times larger than the correct derivative, and I really do not know where this factor comes from… I would really appreciate any suggestions for finding where I am going wrong.

You should probably use torch.autograd.functional.jacobian. With your approach, you need to zero the gradients as you iterate over the rows (outputs).

1 Like

Hi Alex, thanks for sharing this link. It helped me find out where I went wrong, and in the end I got the same result with my implementation as with PyTorch's jacobian function.

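So, in case anybody reads this later: my mistake was that I calculated the gradient with respect to the normalized input, but I should have calculated it with respect to the raw, unnormalized input. Since the normalized input is x_norm = x * scale_ + min_, the chain rule gives df/dx = df/dx_norm * scale_, so differentiating w.r.t. x_norm inflates the derivative by the factor 1/scale_. Therefore I needed to add another function that normalizes the raw input while keeping the gradient. Here is the corrected implementation: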

def calculate_jacobi(model, inp_data, inp_scaler, out_scaler):
    """Compute the Jacobian of the de-normalized model output w.r.t. the raw input.

    The raw input is normalized and the model output de-normalized with torch
    operations. Both scalers' factors therefore enter the derivative via the
    chain rule.

    Args:
        model: torch module mapping a normalized input batch to a normalized
            output batch.
        inp_data: raw (unnormalized) input tensor, presumably of shape
            (batch, in_dim). It is not modified.
        inp_scaler: fitted scaler exposing ``min_``/``scale_`` for the input.
        out_scaler: fitted scaler exposing ``min_``/``scale_`` for the output.

    Returns:
        numpy array of shape ``(batch, out_dim) + inp_data.shape``. Entry
        ``[i, j, ...]`` is ``d unnorm_output[i, j] / d inp_data[...]``. For a
        model that processes rows independently, entries belonging to other
        rows than ``i`` are zero.
    """
    # Differentiate w.r.t. a detached leaf that shares inp_data's values. The
    # caller's tensor is left untouched, and a non-leaf inp_data works too.
    raw_input = inp_data.detach().requires_grad_(True)

    jacobi_matrices = []
    # Build and differentiate the graph even if the caller runs under torch.no_grad().
    with torch.enable_grad():
        normalized_input = normalize_inp_keeping_gradient(raw_input, inp_scaler)
        output = model(normalized_input)
        unnorm_output = revert_output_normalization_keeping_gradient(output, out_scaler)

        for out in unnorm_output:
            jacobi_matrix = []
            for out_dim in out:
                grad = None
                if out_dim.requires_grad:
                    # autograd.grad returns the gradient instead of accumulating
                    # it into .grad. Nothing has to be reset between output
                    # components, and the model parameters need not be frozen,
                    # which would leave them frozen for any later training.
                    (grad,) = torch.autograd.grad(
                        out_dim, raw_input, retain_graph=True, allow_unused=True
                    )
                if grad is None:
                    # This output component is not connected to the input in
                    # the autograd graph.
                    grad = torch.zeros_like(raw_input)
                jacobi_matrix.append(grad.detach().cpu().numpy())
            jacobi_matrices.append(np.array(jacobi_matrix))

    return np.array(jacobi_matrices)

def normalize_inp_keeping_gradient(inp, scaler):
    """Apply a fitted MinMaxScaler's forward transform using differentiable torch ops.

    Computes ``inp * scaler.scale_ + scaler.min_`` per feature (the last
    dimension). The result stays connected to the autograd graph of ``inp``,
    so gradients w.r.t. the raw input include the ``scale_`` factor.

    Args:
        inp: raw input tensor whose last dimension holds the features, e.g.
            (batch, in_dim). An empty batch is allowed.
        scaler: fitted scaler exposing per-feature ``scale_`` and ``min_``
            arrays (such as sklearn's MinMaxScaler).

    Returns:
        Tensor of the same shape and device as ``inp``. Floating-point inputs
        keep their dtype; other inputs are promoted to torch's default dtype.
    """
    dtype = inp.dtype if inp.is_floating_point() else torch.get_default_dtype()
    device = inp.device
    # Cast the numpy scaler parameters to the tensor's dtype/device, so that
    # broadcasting over the feature axis neither upcasts nor moves the data.
    scale = torch.as_tensor(scaler.scale_, dtype=dtype, device=device)
    minimum = torch.as_tensor(scaler.min_, dtype=dtype, device=device)
    return inp.to(dtype) * scale + minimum

def revert_output_normalization_keeping_gradient(normalized_output, scaler):
    """Invert a fitted MinMaxScaler's transform using differentiable torch ops.

    Computes ``(normalized_output - scaler.min_) / scaler.scale_`` per feature
    (the last dimension). The result stays connected to the autograd graph of
    ``normalized_output``, so gradients pick up the ``1 / scale_`` factor.

    Args:
        normalized_output: tensor whose last dimension holds the features,
            e.g. a (batch, out_dim) model output. An empty batch is allowed.
        scaler: fitted scaler exposing per-feature ``min_`` and ``scale_``
            arrays (such as sklearn's MinMaxScaler).

    Returns:
        Tensor of the same shape and device as ``normalized_output``.
        Floating-point inputs keep their dtype; other inputs are promoted to
        torch's default dtype.
    """
    dtype = (
        normalized_output.dtype
        if normalized_output.is_floating_point()
        else torch.get_default_dtype()
    )
    device = normalized_output.device
    # Cast the numpy scaler parameters to the tensor's dtype/device, so that
    # broadcasting over the feature axis neither upcasts nor moves the data.
    minimum = torch.as_tensor(scaler.min_, dtype=dtype, device=device)
    scale = torch.as_tensor(scaler.scale_, dtype=dtype, device=device)
    return (normalized_output.to(dtype) - minimum) / scale

Equivalently, with the implementation above, one can use PyTorch's jacobian function as follows:

example_input = torch.FloatTensor([[5.543]])
# jacobian() calls its function with the input tensor only, so the model and
# the scalers are bound in a closure.
jacobi_matrix = torch.autograd.functional.jacobian(
    lambda inp: func_to_derive(model, inp, inp_scaler, out_scaler),
    example_input,
    create_graph=True,
)

where `func_to_derive` is defined as follows:

def func_to_derive(model, inp_data, inp_scaler, out_scaler):
    """Map a raw input batch to the de-normalized model output, differentiably.

    torch.autograd.functional.jacobian calls its function with the input
    tensor only, so bind the other arguments first, e.g.
    ``jacobian(lambda x: func_to_derive(model, x, inp_scaler, out_scaler), x0)``.

    Args:
        model: torch module mapping normalized inputs to normalized outputs.
        inp_data: raw (unnormalized) input tensor.
        inp_scaler: fitted scaler exposing ``min_``/``scale_`` for the input.
        out_scaler: fitted scaler exposing ``min_``/``scale_`` for the output.

    Returns:
        The de-normalized model output, connected to ``inp_data``'s autograd graph.
    """
    # Only flip the flag when needed: a tensor that already requires grad may
    # be a non-leaf, and assigning requires_grad on a non-leaf raises.
    if not inp_data.requires_grad:
        inp_data.requires_grad_(True)

    # Normalize with torch ops so the gradient flows back to the raw input.
    normalized_inp_data = normalize_inp_keeping_gradient(inp_data, inp_scaler)

    # The model parameters are left as they are. Differentiating w.r.t. the
    # input (as jacobian() does) does not accumulate into parameter .grad, and
    # freezing them here would silently disable any later training.
    normalized_output = model(normalized_inp_data)

    # Revert the output normalization with torch ops so the gradient of the
    # unnormalized output can be computed.
    return revert_output_normalization_keeping_gradient(normalized_output, out_scaler)