Evaluating gradients of output variables w.r.t. parameters for pixelwise models

Hi.
Suppose I have a ViT trained to perform segmentation, which receives an input x of shape (batch_size, n_channels, n_timesteps, height, width) and outputs a prediction y of shape (batch_size, n_classes, height, width). For each pixel p in the output image, I want to evaluate the gradient with respect to all the model parameters w, i.e. Grad(y_p, w) = \nabla_w NN(x; w)_p, with final shape (batch_size, n_classes, n_parameters) per pixel.
I would like to know the best approach to evaluating this with the tools available in PyTorch.


Hi!
Maybe the most straightforward way to do this is to iterate over the batch elements, the classes, and the pixels to obtain scalar tensors, from which you can compute the gradient with respect to the model’s parameters. I’m thinking of something like this:

import torch
from torch import Tensor

output: Tensor = model(x)  # output.shape: [batch_size, n_classes, height, width]
for output_element in output:  # output_element.shape: [n_classes, height, width]
    for class_output in output_element:  # class_output.shape: [height, width]
        for y_row in class_output:  # y_row.shape: [width]
            for pixel in y_row:  # pixel.shape: [] (scalar)
                grads = torch.autograd.grad(pixel, model.parameters(), retain_graph=True)
                # grads is a tuple with as many tensors as there are model parameters;
                # do whatever you need with grads

Obviously this will be very inefficient in terms of computational time, as it requires batch_size * n_classes * height * width backward passes through your model. If you also store the grads, it will require batch_size * n_classes * height * width times the size of your model in memory, which is most likely prohibitive even with a small model.
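
To get a feel for the numbers, here is a quick back-of-envelope estimate; all the sizes below are assumptions chosen purely for illustration:

# Rough memory estimate for storing every per-pixel gradient (illustrative sizes only).
batch_size, n_classes, height, width = 8, 10, 224, 224
n_parameters = 10_000_000  # even a fairly small model
bytes_per_value = 4        # float32

total_bytes = batch_size * n_classes * height * width * n_parameters * bytes_per_value
print(f"{total_bytes / 1e12:.0f} TB")  # ~161 TB

So even with modest sizes and a small model, storing all of the gradients at once is out of the question.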

There are ways to speed up the computation by parallelizing some of the backward passes using torch.vmap, or by using e.g. torch.autograd.functional.jacobian (see the PyTorch documentation), but it would be far too memory-intensive to compute all of the gradients in parallel. So the way you want to parallelize this computation depends on what you intend to do with your grads. If some of them could be summed together, for instance, you could compute the result much more efficiently than what I showed.
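
For example, if the full per-pixel Jacobian for a (small) batch does fit in memory, torch.func can compute it without Python loops. This is only a minimal sketch: the tiny stand-in model and the sizes are assumptions chosen so the example runs quickly, and you would swap in your own model and inputs:

import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

# Tiny stand-in for the real ViT; names and sizes here are illustrative only.
batch_size, n_channels, n_timesteps, height, width, n_classes = 2, 3, 4, 8, 8, 5
model = nn.Conv2d(n_channels * n_timesteps, n_classes, kernel_size=1)
x = torch.randn(batch_size, n_channels, n_timesteps, height, width)

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def predict(params, buffers, sample):
    # sample: [n_channels, n_timesteps, height, width]; fold time into channels
    inp = sample.reshape(1, n_channels * n_timesteps, height, width)
    out = functional_call(model, (params, buffers), (inp,))
    return out.squeeze(0)  # [n_classes, height, width]

# jacrev differentiates predict with respect to its first argument (the parameters);
# vmap maps over the batch dimension of x while broadcasting params and buffers.
jac = vmap(jacrev(predict), in_dims=(None, None, 0))(params, buffers, x)

# jac is a dict mapping each parameter name to a tensor of shape
# [batch_size, n_classes, height, width, *param.shape]
for name, j in jac.items():
    print(name, tuple(j.shape))

Flattening each per-parameter Jacobian over the parameter dimensions and concatenating them along the last axis would give the (batch_size, n_classes, height, width, n_parameters) tensor from your question. And if the gradients would ultimately be summed over pixels anyway, note that by linearity the sum of per-pixel gradients equals the gradient of the summed output, so a single backward pass on output.sum() (or on a weighted sum of the pixels) gives the same result at a fraction of the cost.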

Could you maybe provide more context?

Hi @valerian.rey, thank you for the answer. I’m interested in using the metric defined here: IEEE Xplore Full-Text PDF, expression 4.