Efficiently computing the per-pixel gradient?

I have a function that takes a single parameter (alpha) as input and outputs an N-by-N image (2048x2048). I want to obtain the gradient of this image with respect to alpha. I'm not talking about a Sobel filter; I want to see how my image changes as I change alpha. The function takes around 2 seconds to evaluate.

```python
grad_img = torch.autograd.functional.jacobian(render_with_alpha, alpha).squeeze(-1)
```

This works and does exactly what I want. However, it takes minutes for a 64x64 image, and I terminated the 1024x1024 run after 16 hours, before it had finished (I'm aiming for 2048x2048).

Another approach, which is certainly too slow, is to call backward() for each pixel. To get this working, I had to rerun the forward pass every time, which makes this method impractical (is there a way around this?).
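On the side question of rerunning the forward pass: in general, the graph can be kept alive between backward calls by passing retain_graph=True, so a single forward pass is enough. Below is a minimal sketch using a hypothetical toy `render_with_alpha` (not the real renderer). It still costs one backward pass per pixel, so it only removes the repeated forward passes.

```python
import torch

def render_with_alpha(alpha, n=4):
    # Hypothetical stand-in for the real renderer: an n x n image that depends on alpha
    x = torch.linspace(0.0, 1.0, n)
    return torch.sin(alpha * x[:, None] + x[None, :])

alpha = torch.tensor(0.5, requires_grad=True)
img = render_with_alpha(alpha)        # one forward pass; the graph is built once
flat = img.reshape(-1)

grads = torch.empty_like(flat)        # plain tensor (does not require grad)
for i in range(flat.numel()):
    # retain_graph=True keeps the graph so the next iteration can reuse it
    (g,) = torch.autograd.grad(flat[i], alpha, retain_graph=True)
    grads[i] = g

grad_img = grads.reshape(img.shape)   # d(img)/d(alpha), one backward pass per pixel
```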

Two alternatives from the PyTorch documentation appear to be jacrev and jacfwd (torch.func.jacrev and torch.func.jacfwd). Since I have a single input and a large output, jacfwd seems ideal. However, I cannot get it to work with my code. If I understand correctly, it does not use autograd, and when I use it the code fails with errors about missing storage:

RuntimeError: Cannot access data pointer of Tensor that doesn’t have storage

I get this error every time I use a view or call detach(), and I do both quite often in my code.
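For reference, this is how torch.func.jacfwd is meant to be called on a scalar-input, image-output function. The sketch uses a hypothetical toy `render_with_alpha` written purely in PyTorch ops; the real renderer is not shown, so this does not settle whether it can be made compatible with jacfwd.

```python
import torch
from torch.func import jacfwd

def render_with_alpha(alpha, n=4):
    # Hypothetical toy renderer; the real one is not shown in the question
    x = torch.linspace(0.0, 1.0, n)
    return torch.sin(alpha * x[:, None] + x[None, :])

alpha = torch.tensor(0.5)
# With a 0-dim input, the Jacobian has the same shape as the output image
grad_img = jacfwd(render_with_alpha)(alpha)
print(grad_img.shape)  # torch.Size([4, 4])
```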

So how can I efficiently compute the per-pixel gradient of the image?

Hi Martens!

The short story is: use forward-mode automatic differentiation.

Backward-mode (regular) autograd computes the gradient of a single scalar output value
with respect to a multi-dimensional set of parameters with one forward / backward pass.
(More generally, it computes the vector-jacobian product for a single “vector” associated
with a non-scalar output tensor.) If you wish to use backward-mode autograd to compute
gradients of multiple elements of some output tensor, you have to pay the price of running
multiple forward / backward passes (even if you have pytorch do this for you under the
hood).
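To make the vector-jacobian statement concrete, here is a small sketch with a made-up linear map y = W @ alpha, whose jacobian is W (W and v are illustrative values, not from the thread). One backward pass with the vector v returns v^T J, a single weighted combination of the jacobian's rows, not the full jacobian.

```python
import torch

alpha = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # multi-element input
W = torch.tensor([[1.0, 0.0, 2.0],
                  [0.0, 3.0, 1.0]])
y = W @ alpha                      # 2-element output; its jacobian is W

v = torch.tensor([1.0, 1.0])       # the "vector" in the vector-jacobian product
y.backward(v)                      # one backward pass computes v^T J
print(alpha.grad)                  # tensor([1., 3., 3.])
```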

On the other hand, forward-mode autograd computes the vector derivative of an output
tensor with respect to a single scalar input (that is, it computes the derivatives of the
elements of the output tensor with respect to the input) with a single (dual-mode) forward
pass. (More generally, it computes the jacobian-vector product for a single “vector”
associated with a non-scalar input tensor.) If you wish to use forward-mode autograd
to compute the derivative of the output tensor with respect to multiple elements of some
input tensor, you have to pay the price of running multiple dual-mode forward passes.
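The forward-mode counterpart on the same illustrative linear map: one dual-mode forward pass with a direction u returns J u, a single column combination. This sketch uses torch.func.jvp; choosing u as the first basis vector picks out the first column of W.

```python
import torch
from torch.func import jvp

W = torch.tensor([[1.0, 0.0, 2.0],
                  [0.0, 3.0, 1.0]])

def f(alpha):
    return W @ alpha                   # jacobian of f is W

alpha = torch.tensor([1.0, 2.0, 3.0])
u = torch.tensor([1.0, 0.0, 0.0])      # direction: first input element only
out, ju = jvp(f, (alpha,), (u,))       # one dual-mode forward pass gives J u
print(ju)                              # tensor([1., 0.])
```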

(As a consequence of this we have the general rule that if you want to compute the
full jacobian of some tensor output with respect to some tensor input, you would prefer
backward-mode when the number of elements of the input is significantly larger than the
number of elements of the output and prefer forward mode when the size of the output
is larger than the size of the input.)
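As a sketch of the rule of thumb: for a scalar input and a 7-element output (a made-up function below), torch.func.jacfwd and torch.func.jacrev return the same jacobian, but jacfwd needs only one jvp (one input basis vector) while jacrev needs seven vjps (one per output basis vector), vectorized with vmap.

```python
import torch
from torch.func import jacfwd, jacrev

coeffs = torch.arange(7.)

def f(alpha):
    # scalar input -> 7-element output ("tall" jacobian)
    return coeffs * alpha ** 2

alpha = torch.tensor(3.0)
J_fwd = jacfwd(f)(alpha)   # vmaps a jvp over the 1 input basis vector
J_rev = jacrev(f)(alpha)   # vmaps a vjp over the 7 output basis vectors
print(torch.allclose(J_fwd, J_rev))  # True
```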

In your case, you wish to compute derivatives of many output values (your pixels) with
respect to a single “input” value (your alpha), so you are likely to get dramatically
improved performance if you use forward-mode autograd with a single dual-mode
forward pass.

Here is an example script:

```python
import torch
print (torch.__version__)

alpha = torch.zeros (1)                        # a single scalar parameter
direc = torch.ones (1)                         # the (trivial) direction for the directional derivative

coeffs = torch.arange (7.)                     # these will be the derivatives with respect to alpha

print ('coeffs = ...')
print (coeffs)

with torch.autograd.forward_ad.dual_level():   # forward-mode context manager
    alpha_dual = torch.autograd.forward_ad.make_dual (alpha, direc)
    output = coeffs * alpha_dual               # (dual version of) vector output
    primal, tangent = torch.autograd.forward_ad.unpack_dual (output)
    print ('tangent = ...')
    print (tangent)                            # vector derivative of output with respect to scalar alpha
```

And here is its output:

```
2.6.0+cu126
coeffs = ...
tensor([0., 1., 2., 3., 4., 5., 6.])
tangent = ...
tensor([0., 1., 2., 3., 4., 5., 6.])
```
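The same pattern carries over from a 7-element vector to a full image. Below is a sketch with a hypothetical toy `render_with_alpha` (not the real renderer) that produces an n x n image in one dual-mode forward pass and compares the tangent with a central finite difference. It assumes every operation inside the render function supports forward-mode AD.

```python
import torch
import torch.autograd.forward_ad as fwAD

def render_with_alpha(alpha, n=64):
    # Hypothetical toy renderer standing in for the real one: an n x n image
    x = torch.linspace(0.0, 1.0, n)
    return torch.sin(alpha * x[:, None] + x[None, :])

alpha = torch.tensor(0.5)
with fwAD.dual_level():
    alpha_dual = fwAD.make_dual(alpha, torch.ones_like(alpha))  # tangent d(alpha) = 1
    img_dual = render_with_alpha(alpha_dual)                    # single dual-mode forward pass
    img, tangent = fwAD.unpack_dual(img_dual)
    grad_img = tangent.clone()                                  # plain copy of d(img)/d(alpha)

# Sanity check against a central finite difference
eps = 1e-3
fd = (render_with_alpha(alpha + eps) - render_with_alpha(alpha - eps)) / (2 * eps)
print(grad_img.shape, (grad_img - fd).abs().max().item())
```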

Best.

K. Frank


Thanks for the answer. Unfortunately, I'm having problems running my code with this. The code works fine without forward-mode AD.

Cell In[5], line 24
     23 alpha_dual = torch.autograd.forward_ad.make_dual(alpha, torch.tensor(1.0)) 
---> 24 output_dual = render_with_alpha(alpha_dual)
     25 primal, tangent = torch.autograd.forward_ad.unpack_dual(output_dual)

Cell In[5], line 6, in render_with_alpha(a)
      4 modified_vertices = vertices.clone()
      5 modified_vertices[:, 1] += a
----> 6 render, _ = render_result_func(modified_vertices, smoothing=1, temperature=1)
      7 return render

Cell In[3], line 21, in render_result_func(new_vertices, smoothing, temperature, view_n)
     19 visibility_start = pc()
     20 with torch.no_grad():
---> 21     visibility_weights = get_visibility_weights(new_vertices, faces, edges, face_pairs_vec, is_boundary_edge, view_n, temperature=temperature)
     23 edge_weights = ocg_weights * visibility_weights
     24 render_res_start = pc()

File ~/diff-oc/diffoc/geometry/visibility.py:89, in get_visibility_weights(vertices, faces, edges, face_pairs_vec, is_boundary_edge, view_n, temperature)
     87 print(f"edge_midpoints.device : {edge_midpoints.device}")
     88 print(f"ray_origins.device : {ray_origins.device}")
---> 89 faces_z[n_rays_arange, fp_0] = -(edge_midpoints[:,2] - ray_origins[:,2])
     90 ray_directions = query_vertex_rep - ray_origins
     91 mesh = trimesh.Trimesh(vertices=vertices.detach().cpu(), faces=faces.detach().cpu(), process=False, use_embree=True)

File ~/anaconda3/envs/diffoc/lib/python3.12/site-packages/torch/utils/_device.py:104, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
    102 if func in _device_constructors() and kwargs.get('device') is None:
    103     kwargs['device'] = self.device
--> 104 return func(*args, **kwargs)

RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Why would switching from backward-mode to forward-mode differentiation reveal a CUDA error? Is there any hope of debugging this without recompiling PyTorch?

Hi Martens!

I don’t have anything to offer about this sort of CUDA error.

You might try running your computation entirely on the cpu. Maybe that would produce
a more informative error message or perhaps the error would go away.

I would also recommend posting the pytorch and cuda versions that you are using in the
hope that someone may have seen this before. If you’re not already doing so, you should
try running this on the latest version of pytorch / cuda (assuming that the latest version runs
on your gpu).
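A small sketch for collecting the version information suggested above. CUDA_LAUNCH_BLOCKING=1 comes from the error message itself; it only takes effect if set before CUDA is initialized, so in a notebook it is safest to set it before importing torch (or after restarting the kernel). The `torch.set_default_device("cpu")` line assumes the code relies on a default device, which the DeviceContext frame in the traceback hints at but does not prove.

```python
import os

# Must be set before CUDA is initialized so kernel errors surface at the failing call
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import torch

info = {
    "torch": torch.__version__,
    "cuda (build)": torch.version.cuda,          # None for CPU-only builds
    "cuda available": torch.cuda.is_available(),
}
if torch.cuda.is_available():
    info["gpu"] = torch.cuda.get_device_name(0)
for key, value in info.items():
    print(f"{key}: {value}")

# Assumption: if the code uses a default device, this is one way to try a CPU-only run
torch.set_default_device("cpu")
```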

Best.

K. Frank