Is it possible to access and edit the graph on the go?
What I mean is the following:
Say we have variables a and b, and we can compute c as a function of a and b, i.e. c = f(a, b). For forward propagation we need a, b, and the function; for backward propagation we need the derivative of the function and any gradients flowing in from further up the graph.
What if the function f were too expensive, and instead of computing c with f we had a cheaper workaround, say a function g? We can't simply use g in our computations, as g is not differentiable whereas f is. Is there any way to use the output c (computed via g) for forward propagation but still use f's derivative during backward propagation?
I hope the explanation is clear enough. Let me know if it’s not!
If the forward computation of f is too expensive, why is its backward feasible? If you use autograd, the backward will have a similar runtime. Do you have a simple formula for it?
f is expensive (both computationally and memory-wise), but it's not intractable. Plus, I can't use g for backpropagation as it's not differentiable.
Is it possible to manually replace the derivative of g with the derivative of f?
On the one hand I have bilinear sampling, which is our f; on the other hand, since I know that in my case bilinear sampling will only be used for cropping, I can manually crop the region (which I denoted as g).
However, since I want the gradient to flow through to allow the model to learn how to crop on its own, I have to use a differentiable function. Note that there are a few more details that make bilinear sampling much more expensive than straightforward cropping.
I guess the simplest way is to implement a custom autograd function here:
import torch
from torch import autograd

class CheapF(autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # cheap forward pass: compute the output via g instead of f
        return g(*args)

    @staticmethod
    def backward(ctx, *grad_outs):
        # hand-written backward of the expensive f
        return f_backward(*grad_outs)

# Use it as
outputs = CheapF.apply(inputs)
Note that you will need to implement the backward of f yourself. You cannot rely on the autograd to do it for you as you never computed the forward operation.
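To make the pattern concrete, here is a minimal, self-contained sketch of the same idea (this is my own toy example, not the cropping case above): the forward pass uses a cheap, non-differentiable op (rounding plays the role of g), while the backward pass uses the gradient of a differentiable surrogate (the identity plays the role of f, whose derivative is 1). This is essentially a straight-through estimator.

```python
import torch
from torch.autograd import Function

class RoundST(Function):
    @staticmethod
    def forward(ctx, x):
        # g: cheap, non-differentiable forward (rounding)
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_out):
        # surrogate f(x) = x has df/dx = 1, so pass the gradient through unchanged
        return grad_out

x = torch.tensor([0.4, 1.6], requires_grad=True)
y = RoundST.apply(x)   # forward uses g: tensor([0., 2.])
y.sum().backward()     # backward uses f's derivative: x.grad is tensor([1., 1.])
```

Without the custom Function, calling `torch.round` directly would produce zero gradients everywhere, so the model upstream could never learn. The same structure applies to the cropping case: return the manually cropped tensor in `forward` and implement the bilinear-sampling gradient in `backward`.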