Getting gradients with respect to input features for a batched input

Hi there! I am trying to use torch autograd to get the gradient of the output of a CNN, with respect to the input features. I can do this for a single batch element, but can’t see a way to do this for a batch of inputs.

We take some 28x28 data vector, pass it through a CNN that preserves the input shape, and then project this down to a single scalar using a dot product. I am trying to find the gradient of that scalar, with respect to the 28x28 input features. A simple example below:

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple CNN:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1,padding="same")
        self.conv2 = nn.Conv2d(32, 32, 3, 1,padding="same")
        self.conv2 = nn.Conv2d(32, 1, 3, 1,padding="same")

    def forward(self, x):
        if len(x)<4: ## Enable processing of batch_size=1
            x=x.unsqueeze(0)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        return x


if torch.cuda.is_available():
    print("CUDA Available")
    device = torch.device('cuda')
else:
    print('CUDA Not Available')
    device = torch.device('cpu')

batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device,requires_grad=True)
vectors=torch.rand((batch_size,28*28),device=device,requires_grad=True)
model = SimpleCNN().to(device=device)

def get_single_grad(x,vector,model):
    """ Get gradient of our dot product with respect to a single batch element """
    y_single=model(x)
    prod=torch.dot(y_single.view(28*28),vector)
    grads=torch.autograd.grad(prod,x)[0]
    return grads

We can get the grad of the dot product for each batch, by looping in the following way, which works:

%%time
single_grads=[]
for aa in range(len(data)):
    grads=get_single_grad(data[aa],vectors[aa],model)
    single_grads.append(grads)
    
## output:
#CPU times: user 14.8 ms, sys: 1.02 ms, total: 15.9 ms
#Wall time: 15.4 ms

But when I try and vmap this function over the batch dimension:

batched_grad=torch.func.vmap(get_single_grad,in_dims=(0,0,None))
batched_grad(data,vectors,model)

I get the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [59], in <cell line: 2>()
      1 batched_grad=torch.func.vmap(get_single_grad,in_dims=(0,0,None))
----> 2 batched_grad(data,vectors,model)

File /ext3/miniconda3/lib/python3.9/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
    187def wrapped(*args, **kwargs):
--> 188return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)

File /ext3/miniconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:278, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    274return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    275                          args_spec, out_dims, randomness, **kwargs)
    277 # If chunk_size is not specified.
--> 278return _flat_vmap(
    279     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    280 )

File /ext3/miniconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     41 @functools.wraps(f)
     42def fn(*args, **kwargs):
     43with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44return f(*args, **kwargs)

File /ext3/miniconda3/lib/python3.9/site-packages/torch/_functorch/vmap.py:391, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    389try:
    390     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391     batched_outputs = func(*batched_inputs, **kwargs)
    392return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    393finally:

Input In [57], in get_single_grad(x, vector, model)
     31 y_single=model(x)
     32 prod=torch.dot(y_single.view(28*28),vector)
---> 33 grads=torch.autograd.grad(prod,x)[0]
     34return grads

File /ext3/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:411, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    407     result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
    408         grad_outputs_
    409     )
    410else:
--> 411     result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    412         t_outputs,
    413         grad_outputs_,
    414         retain_graph,
    415         create_graph,
    416         inputs,
    417         allow_unused,
    418         accumulate_grad=False,
    419     )  # Calls into the C++ engine to run the backward pass
    420if materialize_grads:
    421if any(
    422         result[i]isNoneandnot is_tensor_like(inputs[i])
    423for iin range(len(inputs))
    424     ):
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Given there are no issues with the single element grad, I’m guessing the dot product operation is not tracking gradients inside of vmap? Any insight/suggestions would be much appreciated.

Hi @Pericules,

There’s a working code below, note When using torch.func methods you need to only use torch.func methods and not mix them with standard torch.autograd.grad operations.

More information of per-sample gradients can be found in the docs here.

Here’s the working code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.func import functional_call, vmap, grad

torch.manual_seed(0)

# Here's a simple CNN:
class SimpleCNN(nn.Module):
  def __init__(self):
    super(SimpleCNN, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1,padding="same")
    self.conv2 = nn.Conv2d(32, 32, 3, 1,padding="same")
    self.conv2 = nn.Conv2d(32, 1, 3, 1,padding="same")

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    return x
        
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device,requires_grad=True)
vectors=torch.rand((batch_size,28*28),device=device,requires_grad=True)
model = SimpleCNN().to(device=device)
params = dict(model.named_parameters())
y = model(data)

def calc_output(params, data, vector):
	y_single = functional_call(model, params, data)
	prod = torch.dot(y_single.view(28*28), vector)
	return prod
	
grads = vmap(grad(calc_output, argnums=(1)), in_dims=(None, 0, 0))(params, data, vectors)
print(grads.shape) #returns shape [64, 1, 28, 28]

Thank you! The snippet works, I just realised that it doesn’t fit for my use case, as I need to reuse the model predictions for another component in the loss. The loss function is:

CodeCogsEqn (26)
where v are our random vectors, s_theta is the NN, and I would like to compute this loss across a batch. The problem with the snippet above is that the model is evaluated inside the vmap, so we cannnot use the output for the second loss component.

For a single batch element, this is straightforward - using the variables from your snippet:

Nx=28
data_single=data[0]
vector_single=vectors[0]

y=model(data_single)
## Flatten to vector for dot products
y=y.view(Nx*Nx)
prod=torch.dot(y,vector_single)
## Get gradient with respect to input features
grads=torch.autograd.grad(prod, data_single,retain_graph=True)[0]
grads=grads.view(Nx*Nx)
## Final dot product to get scalar loss value
loss1=torch.dot(vector_single,grads)
## Second
loss2=0.5*torch.sum(y**2)

So in the batched case, I am trying to implement something that would look like:

y=model(data)
## Flatten to batches of vectors, for dot products
y=y.view(batch_size,Nx*Nx)
## Batched dot product
prod=torch.einsum("ij,ij->i",y,vectors)
## Get gradient with respect to input features for each batch element
grads=torch.autograd.grad(prod, data,retain_graph=True)[0]
grads=grads.view(batch_size,Nx*Nx)
## Final dot product to get scalar loss value for each batch element
loss1=torch.einsum("ij,ij->i",vectors,grads)
## Second loss term
loss2=0.5*torch.einsum("ij,ij->i",y,y)

The difficulty is that prod is now a vector quantity, so grad does not work on this. I tried wrapping the single-element case in vmap, calling the model using functional_call:

def single_loss(params, data_single, vector):
    y=functional_call(model, params, data_single)
    y=y.view(Nx*Nx)
    prod=torch.dot(y,vector_single)
    ## Get gradient with respect to input features
    grads=torch.autograd.grad(prod, data_single,retain_graph=True)[0]
    grads=grads.view(Nx*Nx)
    loss1=torch.dot(vector_single,grads)
    loss2=0.5*torch.sum(y**2)
    return loss1,loss2

The other approach would be to use is_grads_batched=True in grad, which I would assume is designed to process batched inputs. This returns an output, however it appears to calculate gradients between batch dimensions, which is extremely inefficient. Am beginning to wonder whether this is possible in pytorch?

You’re mixing torch.autograd.grad with torch.func.grad, as I mentioned above,

Here’s a working version of the code:

You can return intermediate values of a vmap call via using the has_aux flag, which will tell grad, jacrev etc. to ignore taking the gradient of these outputted terms.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.func import functional_call, vmap, grad

torch.manual_seed(0)

# Here's a simple CNN:
class SimpleCNN(nn.Module):
  def __init__(self):
    super(SimpleCNN, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1,padding="same")
    self.conv2 = nn.Conv2d(32, 32, 3, 1,padding="same")
    self.conv2 = nn.Conv2d(32, 1, 3, 1,padding="same")

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device,requires_grad=True)
vectors=torch.rand((batch_size,28*28),device=device,requires_grad=True)
model = SimpleCNN().to(device=device)

params = dict(model.named_parameters())

def calc_output(params, data, vector):
	y_single = functional_call(model, params, data)
	y_single = y_single.view(28*28)
	prod = torch.dot(y_single, vector)
	return prod, y_single
	
grads, y = vmap(grad(calc_output, argnums=(1), has_aux=True), in_dims=(None, 0, 0))(params, data, vectors) #returns shape [64, 1, 28, 28], [64, 784]

grads = grads.reshape(-1, 28*28) #flatten

loss1 = vmap(torch.dot, in_dims=(0,0))(vectors, grads)
loss2 = vmap(torch.linalg.norm, in_dims=(0))(y)

loss = loss1 + loss2 #shape [64,]

#do torch.mean(loss) etc... from here...
1 Like