Edit the computational graph on the go?

Is it possible to access and edit the graph on the go?

What I mean is the following:

Say we have variables a and b, and we compute c as a function of a and b, i.e. c = f(a, b). For forward propagation we need a, b, and the function itself; for backward propagation we need the derivative of the function together with the partial derivatives flowing in from the rest of the graph.
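(In chain-rule terms: for a downstream loss L, the backward pass computes ∂L/∂a = ∂L/∂c · ∂f/∂a and ∂L/∂b = ∂L/∂c · ∂f/∂b, where ∂L/∂c is the gradient arriving from the rest of the graph.)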

What if the function f were too expensive, and instead of calculating c via f we had a workaround, say a function g? We can't simply use g in our computations, as g is not differentiable whereas f is. Is there any way to use this output c (computed via g) for the forward propagation but still use f during backward propagation?

I hope the explanation is clear enough. Let me know if it’s not!

Hi,

If the forward computation of f is too expensive, why is its backward feasible? If you use autograd, the backward will have a similar runtime. Do you have a simple formula for it?


f is expensive (both computationally and memory-wise), but it's not intractable. Plus, I can't use g for backpropagation, as it's not differentiable.

Is it possible to manually replace the derivative of g with that of f?

On the one hand I have bilinear sampling, which is our f; on the other hand, since I know that in my case bilinear sampling will only ever be used for cropping, I can crop the region directly (which I denoted as g).

However, since I want the gradient to flow through so the model can learn how to crop on its own, I have to use a differentiable function. Note that there are a few more details that make bilinear sampling much more expensive than straightforward cropping.
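For concreteness, here is a rough sketch of what f and g could look like in this cropping scenario; the interface (a crop box in normalized [-1, 1] coordinates and a square output) is an assumption for illustration, not the exact code in question:

import torch
import torch.nn.functional as F

def f_bilinear_crop(img, box, out_size):
    # Differentiable crop via bilinear sampling: builds a full sampling grid
    # and interpolates every output pixel (differentiable w.r.t. img and box).
    x0, y0, x1, y1 = box
    ts = torch.linspace(0, 1, out_size, device=img.device)
    gy = (y0 + (y1 - y0) * ts).view(-1, 1).expand(out_size, out_size)
    gx = (x0 + (x1 - x0) * ts).view(1, -1).expand(out_size, out_size)
    grid = torch.stack((gx, gy), dim=-1).unsqueeze(0).repeat(img.size(0), 1, 1, 1)
    return F.grid_sample(img, grid, mode='bilinear', align_corners=True)

def g_hard_crop(img, box, out_size):
    # Cheap shortcut: round the box to pixel indices and slice; not
    # differentiable w.r.t. the box, so it cannot be used to learn the crop.
    _, _, h, w = img.shape
    x0, y0, x1, y1 = [float(v) for v in box]
    i0, i1 = int((y0 + 1) / 2 * (h - 1)), int((y1 + 1) / 2 * (h - 1)) + 1
    j0, j1 = int((x0 + 1) / 2 * (w - 1)), int((x1 + 1) / 2 * (w - 1)) + 1
    return F.interpolate(img[:, :, i0:i1, j0:j1], size=(out_size, out_size), mode='nearest')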

Any insights on this? @albanD

I guess the simplest way is to implement a custom autograd function here:

import torch

class CheapF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # cheap, non-differentiable shortcut g for the forward pass
        return g(*args)

    @staticmethod
    def backward(ctx, *grad_outs):
        # hand-written backward of the expensive function f
        return f_backward(*grad_outs)

# Use it as
outputs = CheapF.apply(inputs)

Note that you will need to implement the backward of f yourself. You cannot rely on autograd to do it for you, as the forward of f was never actually computed.
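To make the pattern concrete, here is a toy, self-contained sketch (the cubic f and the helper name cheap_g are made up for illustration; this is not the bilinear-sampling case). The forward returns the cheap result and saves whatever the hand-written backward of f needs:

import torch

class ExpensiveCube(torch.autograd.Function):
    # Pretend f(x) = x ** 3 is expensive to evaluate, while cheap_g below
    # produces the same values through a cheap, non-differentiable shortcut.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # keep the inputs f's backward needs
        return cheap_g(x)               # cheap forward

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2    # analytic derivative of f

def cheap_g(x):
    # stand-in for the non-differentiable shortcut
    return x.detach() ** 3

x = torch.randn(5, requires_grad=True)
y = ExpensiveCube.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, 3 * x.detach() ** 2))  # True: gradient follows f

The printed check confirms that the gradient matches f's analytic derivative even though the forward never went through f.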