Extending PyTorch Gradients

Hello all!

I am trying to make strides toward recreating the code found here, GitHub - google-research/neural-structural-optimization: Neural reparameterization improves structural optimization, in PyTorch. In their code base, in neural-structural-optimization/autograd_lib.py at master · google-research/neural-structural-optimization · GitHub, they extend a function called find_root with a custom gradient. Unfortunately, I am still in the weeds of learning torch and was wondering if someone might have some insights about how to translate these functions to torch.

These are the functions I am trying to replicate:

import autograd
import autograd.core
import autograd.extend
import autograd.numpy as anp
import numpy as np

# an alternative to the optimality criteria
def sigmoid_with_constrained_mean(x, average):
    # sigmoid and logit here are plain numpy-based helpers defined elsewhere
    # in the original code base
    f = lambda x, y: sigmoid(x + y).mean() - average
    lower_bound = logit(average) - np.max(x)
    upper_bound = logit(average) - np.min(x)
    b = autograd_lib.find_root(f, x, lower_bound, upper_bound)

    return sigmoid(x + b)


# internal utilities
def _grad_undefined(_, *args):
    raise TypeError('gradient undefined for this input argument')


def _zero_grad(_, *args, **kwargs):
    # VJP that always returns a zero gradient (used for the bound arguments)
    def jvp(grad_ans):
        return 0.0 * grad_ans
    return jvp


@autograd.primitive
def find_root(
    f, x, lower_bound, upper_bound, tolerance=1e-12, max_iterations=64
):
    # Implicitly solve f(x,y)=0 for y(x) using binary search.
    # Assumes that y is a scalar and f(x,y) is monotonic in y.
    for _ in range(max_iterations):
        y = 0.5 * (lower_bound + upper_bound)
        if upper_bound - lower_bound < tolerance:
            break
        if f(x, y) > 0:
            upper_bound = y
        else:
            lower_bound = y

    return y


def grad_find_root(y, f, x, lower_bound, upper_bound, tolerance=None):
    # This uses a special case of the adjoint gradient rule:
    # http://www.dolfin-adjoint.org/en/latest/documentation/maths/3-gradients.html#the-adjoint-approach
    def jvp(grad_y):
        g = lambda x: f(x, y)
        h = lambda y: f(x, y)
        gradient_value = -autograd.grad(g)(x) / autograd.grad(h)(y) * grad_y
        return gradient_value
    return jvp


# one VJP per positional argument of find_root: f (gradient undefined),
# x (adjoint rule), then zero gradients for lower_bound, upper_bound, tolerance
autograd.extend.defvjp(
    find_root,
    _grad_undefined,
    grad_find_root,
    _zero_grad,
    _zero_grad,
    _zero_grad,
)
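
From what I understand, defvjp registers one VJP per positional argument of find_root, and my best guess is that the torch analogue of @autograd.primitive plus defvjp is a torch.autograd.Function subclass whose backward returns one gradient (or None) per forward input. Here is a toy example of that mapping as I currently understand it (ScaleBy is just a made-up name for illustration, not part of either code base):

import torch


class ScaleBy(torch.autograd.Function):
    # forward plays the role of @autograd.primitive; backward plays the role
    # of defvjp, returning one gradient per forward input.

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # a real VJP for x, and None for factor (the analogue of _zero_grad /
        # _grad_undefined for arguments we do not differentiate through)
        return grad_out * ctx.factor, None

At least, that is my current mental model of how defvjp maps onto torch.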

and this is what I have done so far:

def sigmoid_with_constrained_mean(x, average):
    """
    Function that will compute the sigmoid with the contrained
    mean
    """
    if not torch.is_tensor(average):
        average = torch.tensor(average)

    f = lambda x, y: torch.sigmoid(x + y).mean() - average  # noqa
    lower_bound = torch.logit(average) - torch.max(x)
    upper_bound = torch.logit(average) - torch.min(x)
    b = find_root(f, x, lower_bound, upper_bound)

    return torch.sigmoid(x + b)


def find_root(
    f, x, lower_bound, upper_bound, tol=1e-12, max_iterations=64
):
    """
    Find the root for the volume constraint by binary search
    """
    for _ in range(max_iterations):
        y = 0.5 * (lower_bound + upper_bound)
        if (upper_bound - lower_bound) < tol:
            break
        if f(x, y) > 0:
            upper_bound = y
        else:
            lower_bound = y
    return y
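
For the custom gradient itself, my best guess so far (and I am not at all sure it is correct) is to wrap the binary search in a torch.autograd.Function and implement the adjoint rule from grad_find_root in backward using torch.autograd.grad. The FindRoot class below is just a sketch of what I have in mind, not verified code:

import torch


class FindRoot(torch.autograd.Function):
    # Sketch: binary-search root finder with an implicit-function-theorem
    # backward, meant to mirror find_root / grad_find_root above.

    @staticmethod
    def forward(ctx, f, x, lower_bound, upper_bound, tolerance, max_iterations):
        # run the binary search without building a graph
        with torch.no_grad():
            for _ in range(max_iterations):
                y = 0.5 * (lower_bound + upper_bound)
                if upper_bound - lower_bound < tolerance:
                    break
                if f(x, y) > 0:
                    upper_bound = y
                else:
                    lower_bound = y
        ctx.f = f
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # adjoint rule: dy/dx = -(df/dx) / (df/dy), evaluated at the root
        x, y = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            y_ = y.detach().requires_grad_(True)
            out = ctx.f(x_, y_)
            df_dx, df_dy = torch.autograd.grad(out, (x_, y_))
        grad_x = -df_dx / df_dy * grad_y
        # one entry per forward input (f, x, lower_bound, upper_bound,
        # tolerance, max_iterations); only x gets a gradient
        return None, grad_x, None, None, None, None

and then inside sigmoid_with_constrained_mean I would call it as b = FindRoot.apply(f, x, lower_bound, upper_bound, 1e-12, 64).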

I am wondering whether something like this is the right way to get the equivalent of the autograd version above, or whether there is a cleaner approach. Thank you for any comments or help in advance! :tada: