# Functorch fails to compute Jacobian when subsetting outputs

I want to use functorch to calculate the Jacobian matrix of a torch module for some given inputs. I am only interested in the Jacobian of a certain subset of the output features, so I subset the module’s output. However, functorch then raises an error.

This is a working example without output masking.

``````# test without masks
from torch.func import jacrev, vmap

device = torch.device('cuda:0')  # run this example on the first CUDA GPU

z_dim = 200  # input feature size: in_features of Decoder.L1 and the width of z

class Decoder(nn.Module):
    """Five-layer MLP mapping a ``z_dim``-sized input to 3004 outputs."""

    def __init__(self):
        """Create the four hidden layers (width 40) and the output layer."""
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40, 40)
        self.L3 = nn.Linear(40, 40)
        self.L4 = nn.Linear(40, 40)
        self.Ln = nn.Linear(40, 3004)

    def forward(self, z):
        """Apply four ReLU-activated linear layers, then the output layer.

        ``z`` is never indexed here, so the same code works for a batched
        input and for a single un-batched sample.
        """
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        z = torch.relu(self.L3(z))
        z = torch.relu(self.L4(z))
        return self.Ln(z)

N = 1110  # number of samples (rows of z)
z = torch.randn([N, z_dim]).to(device)  # random inputs of shape [N, z_dim], moved to `device`

fn = Decoder().to(device)

# Jacobian of the output w.r.t. the input via jacrev, vectorized over the samples in z with vmap
jacobian_fn = vmap(jacrev(fn))
jacobians = jacobian_fn(z)
``````

This is the subset version, which fails.

``````# test with masks

# Create a tensor of 3000 zeros
# Randomly select 170 indices to be 1
ones_indices = torch.randperm(3004)[:170]
# Set those indices to 1
# Shuffle the tensor to randomize the positions of 1s
# NOTE(review): only the index selection is present here; the statements that
# build, fill and shuffle the mask tensor described above are not shown.

from torch.func import jacrev, vmap

z_dim = 200  # input feature size: in_features of Decoder.L1 and the width of z

class Decoder(nn.Module):
    """Three-layer decoder variant used for the output-subsetting test."""

    def __init__(self):
        """Create two hidden layers (width 40) and the 3004-wide output layer."""
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40, 40)
        self.L3 = nn.Linear(40, 3004)

    def forward(self, z):
        """Apply the two ReLU-activated hidden layers to ``z``."""
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        # NOTE(review): the masked return that raises the IndexError (presumably
        # of the form `self.L3(z)[:, <mask>]`) is not present in this snippet, so
        # as written forward() returns None -- confirm against the original post.
        #return self.L3(z)[:, [1,2,4,5,3]] # this will fail either.

fn = Decoder()  # no .to(device) in this variant

N = 1110  # number of samples (rows of z)
z     = torch.randn([N, z_dim])

# Same vmap(jacrev(...)) construction as in the unmasked test
jacobian_fn = vmap(jacrev(fn))
# This is the call that raises the IndexError shown in the traceback
jacobians = jacobian_fn(z)
``````

The error message is:

``````---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[4], line 37
33 z     = torch.randn([N, z_dim])
36 jacobian_fn = vmap(jacrev(fn))
---> 37 jacobians = jacobian_fn(z)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
187 def wrapped(*args, **kwargs):
--> 188     return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:278, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
274     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
275                          args_spec, out_dims, randomness, **kwargs)
277 # If chunk_size is not specified.
--> 278 return _flat_vmap(
279     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
280 )

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
41 @functools.wraps(f)
42 def fn(*args, **kwargs):
---> 44         return f(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:391, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
389 try:
390     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391     batched_outputs = func(*batched_inputs, **kwargs)
392     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
393 finally:

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:500, in jacrev.<locals>.wrapper_fn(*args)
497 @wraps(func)
498 def wrapper_fn(*args):
499     error_if_complex("jacrev", args, is_input=True)
--> 500     vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
501     if has_aux:
502         output, vjp_fn, aux = vjp_out

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
41 @functools.wraps(f)
42 def fn(*args, **kwargs):
---> 44         return f(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:302, in _vjp_with_argnums(func, argnums, has_aux, *primals)
300     diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
301     tree_map_(partial(_create_differentiable, level=level), diff_primals)
--> 302 primals_out = func(*primals)
304 if has_aux:
305     if not (isinstance(primals_out, tuple) and len(primals_out) == 2):

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518         or _global_backward_pre_hooks or _global_backward_hooks
1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
1522 try:
1523     result = None

Cell In[4], line 27, in Decoder.forward(self, z)
25 z = torch.relu(self.L1(z))
26 z = torch.relu(self.L2(z))

IndexError: too many indices for tensor of dimension 1
``````

I want to subset inside the function, rather than subsetting the computed results afterwards, because I want to save GPU memory. The Jacobian computation is extremely GPU-memory-intensive, and I want to fit more samples into one calculation.

Can somebody help me figure out why? Thank you!

Hi @Sijie_Chen,

When you vectorize a function with `vmap`, the function no longer sees the batch dim: inside the function only the output dimension exists. That’s why it says ‘too many indices for tensor of dimension 1’. So, you could replace `[:, feature_mask]` with `[..., feature_mask]`, or apply the mask after the vectorized function call.

Also, just a word of advice: when computing `jacrev` for a function, you need to ‘functionalize’ your model (so that the parameters become an input to the model).

So replace `jacobian_fn = vmap(jacrev(fn))` with,

``````params = dict(fn.named_parameters()) #params need to be an input

#For convenience, define the 'functionalized' call as a small function
#(parameters first, inputs second, so argnums=1 below selects the inputs)
def fcall(params, inputs):
    return torch.func.functional_call(fn, params, inputs)

#Then we vmap/jacrev over the functionalized version of our Decoder.
#jacobian_fn stays a function, so it is called explicitly to get the Jacobians.
jacobian_fn = vmap(jacrev(fcall, argnums=1), in_dims=(None, 0))
jacobians = jacobian_fn(params, z)

#Here we specify the jacrev is against `z` and that vmap vectorize over the batch dim of the inputs (there is no batch dim for the parameters)``````
1 Like