Speeding up gradient computation instead of using a for loop

I have a torch tensor named HSNR of size 4096. I am calculating its gradient w.r.t. the NN parameters with the approach below. However, the process is extremely slow. How can I speed it up?

    for i in range(4096):
        first_gradient_list[i] = torch.autograd.grad(HSNR[i], NN.parameters(), retain_graph=True)

It might be hard to answer this the way you were hoping without more information, but I’d suggest looking into the functorch documentation for this. It’s a semi-recent addition, now available as torch.func, that enables faster and more math-like differentiation.

You could try here for the actual documentation:
https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html#per-sample-grads-the-efficient-way-using-functorch
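
For reference, here is a minimal sketch of that per-sample-gradient recipe written against the newer torch.func API (the model, params and samples below are just placeholders):

    import torch
    from torch.func import functional_call, grad, vmap

    model = torch.nn.Linear(2, 1)            # placeholder model
    params = dict(model.named_parameters())
    samples = torch.randn(4096, 2)           # placeholder batch

    def per_sample_loss(params, sample):
        # any scalar-valued function of (params, one sample) works here
        return functional_call(model, params, sample.unsqueeze(0)).squeeze() ** 2

    # one gradient dict per sample, computed in a single vectorized pass
    per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, 0))(params, samples)
    print(per_sample_grads['weight'].shape)  # torch.Size([4096, 1, 2])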

I created a function like this, where b is the batch dimension, d is the dimension of the input to the forward call, and n is the number of parameters.

    def dn_dtheta(self, x):
        """(b, d) -> (b, n)"""
        # jacrev differentiates w.r.t. the first argument (theta) by default and
        # returns a dict of Jacobians keyed by parameter name
        dn_dtheta_dict = torch.func.jacrev(self._forward_w_params)(self.get_theta(), x)
        dn_dtheta = dict_to_vec_fl(dn_dtheta_dict)  # flatten the dict into a single (b, n) tensor
        return dn_dtheta

in my nn.Module, where I use

    def _forward_w_params(self, theta, x):
        """(n, d) -> (b)"""
        return torch.func.functional_call(model, theta, x)

although my forward function returns just a scalar for each batch element, i.e. (b, d) -> (b).
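
dict_to_vec_fl above is just a small flattening helper of mine, not a torch function; roughly something like this, assuming each Jacobian entry has shape (b, *param_shape):

    def dict_to_vec_fl(jac_dict):
        # flatten each (b, *param_shape) Jacobian block to (b, -1) and
        # concatenate along the parameter axis to give a single (b, n) tensor
        return torch.cat([j.reshape(j.shape[0], -1) for j in jac_dict.values()], dim=1)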


Thanks.

The problem domain is solving the Burgers equation with Physics-Informed Neural Networks (PINNs).

Following the documentation, I am trying to solve the problem like this. However, I am getting the error: “RuntimeError: One of the differentiated Tensors does not require grad”

    data = torch.hstack((X,T))
    # Define a function to compute the loss for a single input 
    def compute_loss_stateless_model(params, buffers, sample):

        batch = sample.unsqueeze(0)

        u_pred = fmodel(params, buffers, batch) 

        # print(sample)
        print("Samples Req Grad : ",sample.requires_grad)
        
        print(batch.requires_grad)

        u_t = torch.autograd.grad(
            u_pred, batch[0][1],
            grad_outputs=torch.ones_like(u_pred),
            retain_graph=True,
            create_graph=True)[0]
    
        u_x = torch.autograd.grad(
            u_pred, batch[0][0], 
            grad_outputs=torch.ones_like(u_pred),
            retain_graph=True,
            create_graph=True)[0]
        
        u_xx = torch.autograd.grad(
            u_x, batch[0][0], 
            grad_outputs=torch.ones_like(u_x),
            retain_graph=True,
            create_graph=True)[0]

        residual = u_t + (u_pred * u_x) - (nu * u_xx)

        # Calculating the Half of Squared Euclidean Norm : half_squaredENorm_residual HSNR
        HSNR = 0.5 * (residual ** 2)
        
        return HSNR
    
    ft_compute_grad = grad(compute_loss_stateless_model) 

    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))

    ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data)

To debug, I found that inside the compute_loss_stateless_model function the requires_grad flag of the sample becomes False. Is there any way around this?

Hi @rajoy99,

Expanding on @Grant_Norman’s answer: you should ideally only use torch.func functions when computing per-sample gradients and not mix torch.func calls with torch.autograd calls (at least not for the calculation of the per-sample gradient). That’s why you get the “One of the differentiated Tensors does not require grad” error.

Within the torch.func namespace, vectorizing over a function means it no longer sees the batch dimension: vmap effectively redefines your function to work on a ‘single’ sample and then computes all of these samples in parallel. So you need to remove your indexing on the input tensor as well (torch.func can’t see the batch dim!).

Also, as you have noticed, the arguments passed to torch.func calls may have their requires_grad flags disabled, because torch.func returns the gradients directly rather than populating the .grad attribute.
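
As a tiny illustration of that behaviour:

    import torch
    from torch.func import grad

    def f(x):
        return (x ** 2).sum()

    x = torch.randn(3)  # no requires_grad needed
    print(grad(f)(x))   # the gradient (2 * x) is returned directly
    print(x.grad)       # None; nothing is populated on the input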

To see all the changes between functorch and torch.func, have a look at the migrating from functorch to torch.func docs here.

Your code above should be re-written to something like this:

    net = Model(*args, **kwargs)  # nn.Module

    params = dict(net.named_parameters())  # params are now dicts in torch.func (rather than lists in functorch)

    from torch.func import functional_call, vmap, jacrev, grad

    def fcall(params, sample):  # functionalized call (with params as input)
        return functional_call(net, params, sample)

    def compute_loss(params, sample):
        u_pred = fcall(params, sample)  # assumes the model returns a scalar per sample

        # gradient of u w.r.t. the input; assuming sample = (x, t), split it into u_x and u_t
        u_grads = grad(fcall, argnums=1)(params, sample)
        u_x, u_t = u_grads[0], u_grads[1]

        # reverse-over-reverse Hessian calc.; u_xx is the (x, x) entry
        # (be careful of the shapes here and try to avoid broadcasting issues)
        u_hess = jacrev(jacrev(fcall, argnums=1), argnums=1)(params, sample)
        u_xx = u_hess[0][0]

        residual = u_t + (u_pred * u_x) - (nu * u_xx)

        HSNR = 0.5 * (residual ** 2)
        return HSNR

    losses = vmap(compute_loss, in_dims=(None, 0))(params, samples)
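
And if you want the per-sample gradients of HSNR w.r.t. the params themselves (which is what your original loop was computing), one option is to wrap compute_loss in grad before vmap-ing, something like:

    # one entry per parameter name, each with the batch dim first: (4096, *param_shape)
    HSNR_grad = vmap(grad(compute_loss, argnums=0), in_dims=(None, 0))(params, samples)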

Now granted, the example above is for a model with only params and not buffers, but changing that should be minimal (I just don’t know how your model is defined, but something like torch.func.stack_module_state can give an example of how to combine buffers with params when vmap-ing).

You’ll then need to modify the grad and jacrev calls to include the buffers object (but that won’t be too difficult).
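
For instance, a rough sketch of how that could look (assuming you only differentiate w.r.t. the params and just carry the buffers along):

    buffers = dict(net.named_buffers())  # e.g. BatchNorm running stats; an empty dict if the model has none

    def fcall(params, buffers, sample):
        # functional_call takes a single dict, so merge params and buffers together
        return functional_call(net, {**params, **buffers}, sample)

    def compute_loss(params, buffers, sample):
        # ... same residual computation as above, with buffers threaded through
        #     every fcall / grad / jacrev call ...
        return 0.5 * fcall(params, buffers, sample) ** 2  # placeholder body for illustration

    # differentiate w.r.t. params only (argnums=0); buffers are broadcast, samples are vmapped
    per_sample_grads = vmap(grad(compute_loss, argnums=0), in_dims=(None, None, 0))(params, buffers, samples)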


Thank you. I made some modifications and it solved my problem of per-sample gradient computation. I have verified that the computations are correct by comparing them against my previous brute-force solution.

Still, I have a question. Since I am computing the gradient w.r.t. the NN parameters, I get the result as a dictionary of tensors.

HSNR_grad.keys()
>>> dict_keys(['0.weight', '0.bias', '2.weight'.....]) 
HSNR_grad['2.weight'].shape 
>>> torch.Size([4096, 20, 20])

From this I construct the list of per-sample gradients by running a for loop over the dictionary, like this:


    keylist = HSNR_grad.keys()

    first_gradient_list = []

    for i in range(4096):
        gradients_at_index = []

        for k in keylist:
            gradients_at_index.append(HSNR_grad[k][i])
        first_gradient_list.append(gradients_at_index)

Is there any approach that would allow me to compute the gradients in a per-sample list format (where each entry is the gradient for that particular sample)? That would be more efficient than running the for loop again.

You could try using torch.unbind to convert the batch dim to a list; docs here
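
For example, roughly (assuming the batch dim comes first in every entry of HSNR_grad):

    # unbind each batched gradient tensor along the batch dim, then regroup by sample;
    # zip(*...) turns per-parameter tuples of samples into per-sample tuples of parameters
    unbound = {k: torch.unbind(v, dim=0) for k, v in HSNR_grad.items()}
    first_gradient_list = list(zip(*unbound.values()))  # first_gradient_list[i]: grads for sample i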
