# Extending PyTorch Gradients

Hello all!

I am trying to recreate the code from the GitHub repository google-research/neural-structural-optimization ("Neural reparameterization improves structural optimization") in `pytorch`. In the file `autograd_lib.py` of that repository (master branch), they extend a function called `find_root` with a custom gradient. Unfortunately, I am still in the weeds of learning `torch` and was wondering if someone might have some insights about how to translate these functions to `torch`.
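For context: in autograd, "extending" a function means marking it as a primitive with `@autograd.extend.primitive` and registering its vector-Jacobian product (VJP) with `autograd.extend.defvjp`. The closest PyTorch counterpart is subclassing `torch.autograd.Function`, where `forward` computes the value and `backward` returns one gradient per input. A minimal sketch of that mapping, using a hypothetical toy primitive `Cube` that is not part of the original code:

```python
import torch


class Cube(torch.autograd.Function):
    """Hypothetical toy primitive y = x**3 with a hand-written backward.

    forward plays the role of an @autograd.extend.primitive function;
    backward plays the role of the VJP registered with defvjp.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep x around for the backward pass
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # VJP: upstream gradient times dy/dx = 3 * x**2
        return grad_out * 3 * x ** 2


x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64, requires_grad=True)
Cube.apply(x).sum().backward()
print(x.grad)  # should equal 3 * x**2, i.e. 0.75, 3.0, 12.0
```

Note that a custom `Function` is invoked through `.apply`, not by instantiating the class.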

These are the functions I am trying to replicate:

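```python
import autograd
import autograd.extend
import autograd.numpy as anp


def sigmoid(x):
    # logistic sigmoid, written in the numerically stable tanh form
    return anp.tanh(0.5 * x) * 0.5 + 0.5


def logit(p):
    # inverse of the logistic sigmoid
    p = anp.clip(p, 0, 1)
    return anp.log(p) - anp.log1p(-p)


# an alternative to the optimality criteria
def sigmoid_with_constrained_mean(x, average):
    f = lambda x, y: sigmoid(x + y).mean() - average
    lower_bound = logit(average) - anp.max(x)
    upper_bound = logit(average) - anp.min(x)
    # find_root lives in autograd_lib.py in the original repository
    b = find_root(f, x, lower_bound, upper_bound)

    return sigmoid(x + b)


# internal utilities
def _grad_undefined(_, *args):
    raise TypeError('gradient undefined for this input argument')


def _zero_grad(_, *args, **kwargs):
    def jvp(grad_ans):
        return 0.0 * grad_ans
    return jvp


@autograd.extend.primitive
def find_root(
    f, x, lower_bound, upper_bound, tolerance=1e-12, max_iterations=64
):
    # Implicitly solve f(x,y)=0 for y(x) using binary search.
    # Assumes that y is a scalar and f(x,y) is monotonic in y.
    for _ in range(max_iterations):
        y = 0.5 * (lower_bound + upper_bound)
        if upper_bound - lower_bound < tolerance:
            break
        if f(x, y) > 0:
            upper_bound = y
        else:
            lower_bound = y

    return y


def grad_find_root(y, f, x, lower_bound, upper_bound, tolerance=None):
    # This uses a special case of the adjoint gradient rule:
    # dy/dx = -(df/dx) / (df/dy), evaluated at the root f(x, y) = 0.
    def jvp(grad_y):
        g = lambda x: f(x, y)
        h = lambda y: f(x, y)
        return -autograd.grad(g)(x) / autograd.grad(h)(y) * grad_y
    return jvp


# one VJP maker per positional argument: f, x, lower_bound, upper_bound
autograd.extend.defvjp(
    find_root,
    _grad_undefined,
    grad_find_root,
    _zero_grad,
    _zero_grad,
)
```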
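For reference, the rule used by `grad_find_root` is the implicit function theorem. If $y(x)$ is defined by $f(x, y(x)) = 0$, differentiating both sides with respect to $x$ gives the derivative of the root, and multiplying by the upstream gradient $\bar{y}$ (`grad_y`) gives the VJP:

$$
f\bigl(x, y(x)\bigr) = 0
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x} + \frac{\partial f}{\partial y}\,\frac{\partial y}{\partial x} = 0
\quad\Longrightarrow\quad
\bar{x} = \bar{y}\,\frac{\partial y}{\partial x} = -\,\bar{y}\,\frac{\nabla_x f(x, y)}{\partial f(x, y)/\partial y}
$$

Here $\nabla_x f$ is the gradient of `g` and $\partial f/\partial y$ is the derivative of `h` in `grad_find_root`; the division requires $\partial f/\partial y \neq 0$, which holds when $f$ is strictly monotonic in $y$. The root itself does not depend on the bracketing bounds as long as they enclose it, which is why `_zero_grad` is registered for `lower_bound` and `upper_bound`.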

and this is what I have done so far:

```python
import torch


def sigmoid_with_constrained_mean(x, average):
    """
    Compute the sigmoid of x, shifted so that its mean equals `average`.
    """
    average = torch.as_tensor(average, dtype=x.dtype)

    f = lambda x, y: torch.sigmoid(x + y).mean() - average  # noqa: E731
    lower_bound = torch.logit(average) - torch.max(x)
    upper_bound = torch.logit(average) - torch.min(x)
    b = find_root(f, x, lower_bound, upper_bound)

    return torch.sigmoid(x + b)


def find_root(
    f, x, lower_bound, upper_bound, tol=1e-12, max_iterations=64
):
    """
    Solve f(x, y) = 0 for the scalar y by bisection (volume constraint).
    Assumes f(x, y) is monotonically increasing in y.
    """
    # Note: if x requires grad, autograd differentiates through these
    # bisection steps, so b only receives gradient via torch.max/torch.min
    # in the bounds, not via the implicit rule used in grad_find_root.
    for _ in range(max_iterations):
        y = 0.5 * (lower_bound + upper_bound)
        if (upper_bound - lower_bound) < tol:
            break

        if f(x, y) > 0:
            upper_bound = y
        else:
            lower_bound = y
    return y
```

I am wondering how to get the equivalent of the custom gradient that `grad_find_root` provides in PyTorch. Thank you in advance for any comments or help!
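One way this could be expressed in PyTorch is a `torch.autograd.Function` whose `forward` runs the bisection (autograd is disabled inside `forward`, so the steps are not recorded) and whose `backward` applies the implicit rule from `grad_find_root`. This is a sketch, not a verified port of the repository: the class name `FindRoot` is hypothetical, and it assumes the bounds are tensors and `f` returns a scalar tensor that is strictly increasing in `y`.

```python
import torch


class FindRoot(torch.autograd.Function):
    """Hypothetical bisection root finder with an implicit-gradient backward.

    Mirrors find_root + grad_find_root registered via defvjp.
    Assumes lower_bound/upper_bound are tensors bracketing the root.
    """

    @staticmethod
    def forward(ctx, f, x, lower_bound, upper_bound, tolerance, max_iterations):
        # forward() runs with grad mode disabled: no graph is built here.
        for _ in range(max_iterations):
            y = 0.5 * (lower_bound + upper_bound)
            if upper_bound - lower_bound < tolerance:
                break
            if f(x, y) > 0:
                upper_bound = y
            else:
                lower_bound = y
        ctx.f = f
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, y = ctx.saved_tensors
        # Re-evaluate f at the root with grad enabled to get df/dx, df/dy.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            y_ = y.detach().requires_grad_(True)
            df_dx, df_dy = torch.autograd.grad(ctx.f(x_, y_), (x_, y_))
        # Implicit function theorem: dy/dx = -(df/dx) / (df/dy).
        grad_x = -grad_y * df_dx / df_dy
        # One entry per forward input: f, x, lower_bound, upper_bound,
        # tolerance, max_iterations. None means no gradient flows there.
        return None, grad_x, None, None, None, None


def find_root(f, x, lower_bound, upper_bound, tolerance=1e-12, max_iterations=64):
    return FindRoot.apply(f, x, lower_bound, upper_bound, tolerance, max_iterations)


def sigmoid_with_constrained_mean(x, average):
    average = torch.as_tensor(average, dtype=x.dtype)
    f = lambda x, y: torch.sigmoid(x + y).mean() - average  # noqa: E731
    # The bounds only bracket the root, so detach them (like _zero_grad).
    lower_bound = (torch.logit(average) - torch.max(x)).detach()
    upper_bound = (torch.logit(average) - torch.min(x)).detach()
    b = find_root(f, x, lower_bound, upper_bound)
    return torch.sigmoid(x + b)


x = torch.randn(10, dtype=torch.float64, requires_grad=True)
out = sigmoid_with_constrained_mean(x, 0.3)
print(out.mean())  # approximately 0.3
```

A quick sanity check on the gradient: since the mean of the output is pinned to `average`, the gradient of `out.mean()` with respect to `x` should be (numerically) zero, and `torch.autograd.gradcheck` on the whole function in float64 should pass.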