Row-wise autograd of a tensor using PyTorch autograd, without a for loop

I am new to PyTorch and am developing a DeepONet code. I need to create a loss function where autograd is applied to each sample of the model output. I used a for loop, but the code is very slow. Is there any way to vectorize the for loop so the gradients for all samples are computed in one go? I am attaching the loss function code:

def derivative(dy: torch.Tensor, x: torch.Tensor, order: int = 1) -> torch.Tensor:
    """
    This function calculates the derivative of the model at x_f
    """
    for i in range(order):
        dy = torch.autograd.grad(
            dy, x, grad_outputs=torch.ones_like(dy), create_graph=True, retain_graph=True
        )[0]
    return dy

def loss_phy_mod(model: nn.Module, y_data, x_data):
    yy = model(y_data, x_data)
    jacobian_rows = [derivative(yy[i, :], x_data, 1)
                     for i in range(y_data.shape[0])]
    jacobian = torch.stack(jacobian_rows)

    loss = torch.mean((y_data.flatten() - jacobian.flatten())**2)

    return loss

Hi Arka!

Composing vmap with grad will likely do what you want.
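
For example, a minimal sketch of that composition on a toy elementwise function (f below is just an illustrative stand-in, not your model):

import torch
from torch.func import grad, vmap

# toy scalar function applied elementwise (stand-in for a per-sample output)
def f(x):
    return x ** 3

x = torch.linspace(0, 1, 5)

# grad(f) differentiates the scalar output w.r.t. its scalar input;
# vmap maps that gradient computation over the batch dimension of x
dfdx = vmap(grad(f))(x)
print(dfdx)  # equals 3 * x ** 2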

Best.

K. Frank

Thanks Frank, I tried using vmap, but it is still not working:

import torch
import functorch

# Function to compute gradients of a single output element w.r.t. input
def compute_grad(y_pred_elem, x):
    grad, = torch.autograd.grad(y_pred_elem, x, create_graph=True, retain_graph=True)
    return grad

# Vectorize this function to compute gradients for all elements
vectorized_grad = functorch.vmap(compute_grad, in_dims=(0, None))

# Compute the Jacobian
jacobian = vectorized_grad(y_pred, x.view(1, 100))

I am getting the error "element 0 of tensors does not require grad and does not have a grad_fn", but x has requires_grad set to True.

Hi @Arka_Roy (and @KFrank),

The reason why it isn’t working with torch.func.vmap is that torch.func.vmap requires the entire process to be within its ‘functionalized’ approach, i.e. you can’t mix torch.autograd operations with torch.func when computing higher derivatives.

You can look at a previous answer I’ve shared on the forums, which focuses on using torch.func to compute the Hessian, here: Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix - #8 by AlphaBetaGamma96
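
For instance, a minimal sketch of that fully torch.func-based pattern for a higher derivative (just a toy scalar function, not the linked answer itself):

import torch
from torch.func import hessian

# toy scalar-valued function; the second derivative is built entirely inside torch.func
def f(x):
    return (x ** 3).sum()

x = torch.linspace(0, 1, 5)

H = hessian(f)(x)         # full Hessian, shape [5, 5]
print(torch.diagonal(H))  # equals 6 * x; off-diagonal entries are zero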

Dear @AlphaBetaGamma96,
Here I am sharing a simple piece of code for the vectorized grad. One version uses the sequential approach and the other uses vmap. The first one works fine, but the second one does not, and I cannot find the error. Could you please try it and correct it?

import torch

# Device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def derivative(dy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the derivative of the model at x_f
    """
    dy = torch.autograd.grad(
            dy, x, grad_outputs=torch.ones_like(dy), create_graph=True, retain_graph=True
                            )[0]
    return dy

def create_func(x: torch.Tensor) -> torch.Tensor:
    yy = torch.empty(2, x.shape[0])
    y1 = x ** 2
    y2 = x ** 3
    yy[0, :] = y1
    yy[1, :] = y2
    return yy

x = torch.linspace(0, 1, 10).type(torch.float32).to(device)
x.requires_grad = True

yy = create_func(x).to(device)

# Sequential approach to calculate the Jacobian
jacobian_rows = [derivative(yy[i, :], x) for i in range(2)]
jacobian = torch.stack(jacobian_rows)

# Function to compute gradients of a single output element w.r.t input
def compute_grad(y_pred_elem, x):
    grad, = torch.autograd.grad(y_pred_elem, x, create_graph=True, retain_graph=True)
    return grad

# Custom vectorization using torch.vmap
vectorized_grad = torch.vmap(compute_grad, in_dims=(0, None))
jacobian = vectorized_grad(yy, x)

print(jacobian)

Hi @Arka_Roy,

So, you’re mixing torch.autograd with torch.func operations, which leads to a zero gradient. You need a complete torch.func approach in order for torch.func to return the correct gradient. An example can be found below:

import torch
from torch.func import jacrev, vmap

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def func(x):
  y1=x**2
  y2=x**3
  return torch.stack([y1,y2],dim=-1)

x = torch.linspace(0,1,10, dtype=torch.float32, device=device) #don't need requires_grad=True with torch.func

jacobian = vmap(jacrev(func, argnums=(0)), in_dims=(0))(x)
print(jacobian)

returns,

tensor([[0.0000, 0.0000],
        [0.2222, 0.0370],
        [0.4444, 0.1481],
        [0.6667, 0.3333],
        [0.8889, 0.5926],
        [1.1111, 0.9259],
        [1.3333, 1.3333],
        [1.5556, 1.8148],
        [1.7778, 2.3704],
        [2.0000, 3.0000]], device='cuda:0', grad_fn=<ViewBackward0>)
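
As a side note, this vmapped Jacobian is simply the transpose of the one produced by the sequential stacking approach, since vmap(jacrev(func)) returns one row per sample rather than one row per output. A small self-contained sketch of that comparison, reusing the same toy func:

import torch
from torch.func import jacrev, vmap

def func(x):
    return torch.stack([x**2, x**3], dim=-1)

x = torch.linspace(0, 1, 10, requires_grad=True)
yy = func(x)  # shape [10, 2]

# sequential approach: one autograd.grad call per output column, stacked -> shape [2, 10]
rows = [torch.autograd.grad(yy[:, i], x, grad_outputs=torch.ones(10), retain_graph=True)[0]
        for i in range(2)]
seq_jacobian = torch.stack(rows)

# vmapped approach: one jacrev per sample -> shape [10, 2]
vmap_jacobian = vmap(jacrev(func))(x.detach())

print(torch.allclose(seq_jacobian, vmap_jacobian.T))  # True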

Dear @AlphaBetaGamma96,
Thanks a lot.

Hi, I have a similar question, but I need to get just one layer’s gradient to save some time in my model. However, the function ‘jacrev’ doesn’t let me choose the exact layer’s parameters to backpropagate to. My full problem is described here: Autograd on a specific layer's parameters.

You’ll need to share a minimal reproducible example of your issue (including the model) so I can understand what’s wrong.

Thank you for your answer. Here is my code:

import torch
from torch import nn
from torch.func import vmap, jacrev, functional_call

act=torch.nn.Tanh()
model = nn.Sequential(
            nn.Linear(in_features=3, out_features=5),
            act,
            nn.Linear(in_features=5, out_features=1, bias=False)
        )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

batch_size = 4
input_features = 3
input_data = torch.randn(batch_size, input_features, device=device)  #init input   
func_params=dict(model.named_parameters())
                
def fm(params,input):
    return functional_call(model,params,input.unsqueeze(0)).squeeze(0)

# loss function                   
def floss(func_params,input):      
    fx=fm(func_params,input)
    return fx

def grad(f,param):
    return torch.autograd.grad(f,param)
per_sample_grads =vmap(jacrev(floss,0), (None, 0))(func_params, input_data)

# I want to get the gradient of ['0.weight']
oweightgrads=vmap(grad(floss,func_params['0.weight']),(None,0))(func_params,input_data)

I’m pretty sure you need to pass all weights to the vmap call, then select the specific weight after the call. As your loss function depends on all the weights, you need to pass all weights to the floss function, even if you just want a subset of the weights.

So, something like this,

import torch 
from torch import nn
from torch.func import vmap, jacrev, functional_call

act=torch.nn.Tanh()
model = nn.Sequential(
            nn.Linear(in_features=3, out_features=5),
            act,
            nn.Linear(in_features=5, out_features=1, bias=False)
        )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
params=dict(model.named_parameters())

batch_size = 4
input_features = 3
input_data = torch.randn(batch_size, input_features, device=device)  #init input   
                
def fm(params,input):
    return functional_call(model,params,input.unsqueeze(0)).squeeze(0)

def floss(params, input):
  return functional_call(model, params, input)

per_sample_grads = vmap(jacrev(floss, argnums=(0)), in_dims=(None,0))(params,input_data) 
print('per_sample_grads: ',per_sample_grads) # returns PyTree
oweightgrads = per_sample_grads['0.weight']
print('oweightgrads: ',oweightgrads) # returns our select weights (shape: oweightgrads:  torch.Size([4, 1, 5, 3]))

Thanks for your answer. So I can’t save time in this step by computing only a specific weight’s gradient. Is there some other approach? I just need one layer’s gradient in my algorithm.

But your loss depends on the output of the network, which depends on all the weights. I don’t see how you can calculate the gradient of the loss for a specific part of the network without using all the weights?

Yeah, but I found this approach works:

gradient=grad(floss(func_params,input_data)[0],func_params['0.weight'])

However, it only works when I choose one batch, which means I can only use it inside a for loop.

If this is the gradient of the loss (as in, averaged over all samples) taken with respect to one particular weight, why not just vmap that call?

Sorry, I may not be catching the point. Do you mean calling vmap like this?

oweightgrads=vmap(grad(floss,func_params['0.weight']),(None,0))(func_params,input_data)

Thank you so much for taking the time to help with my question. I’m happy to share that I have solved it with the following approach. Here is the code:

c = func_params['0.weight']
print(c.shape)

# one standard basis vector per sample, used as grad_outputs to pick out that sample's gradient
basis_vector = torch.eye(batch_size).to(device)
basis_vector = basis_vector.unsqueeze(-1)

loss = floss(func_params, input_data)  # outputs for the whole batch

def grad(v):
    # vector-Jacobian product of the batched outputs w.r.t. the first layer's weight
    return torch.autograd.grad(loss, c, v, retain_graph=True)

# vmap over the basis vectors gives per-sample gradients of layer 0's weight
g = vmap(grad)(basis_vector)

But when I run this code, there is a warning:

UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass