functorch fails to compute the Jacobian when subsetting outputs

I want to use functorch to calculate the Jacobian matrix of a torch module with respect to some inputs. I am only interested in the Jacobian of a certain subset of output features, so I subset the module's output. However, functorch raises an error.

Here is a working example without output masking.

# test without masks
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

device = torch.device('cuda:0')

z_dim = 200

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40,40)
        self.L3 = nn.Linear(40,40)
        self.L4 = nn.Linear(40,40)
        self.Ln = nn.Linear(40, 3004)

    def forward(self, z):
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        z = torch.relu(self.L3(z))
        z = torch.relu(self.L4(z))
        return self.Ln(z)
    

N = 1110
z = torch.randn([N, z_dim]).to(device)

fn = Decoder().to(device)

jacobian_fn = vmap(jacrev(fn))
jacobians = jacobian_fn(z)

Here is the subset version, which fails.

# test with masks
import torch
import torch.nn as nn

# Create a tensor of 3004 zeros
feature_mask = torch.zeros(3004, dtype=torch.long)
# Randomly select 170 indices to be 1
ones_indices = torch.randperm(3004)[:170]
# Set those indices to 1
feature_mask[ones_indices] = 1
# Shuffle the tensor and convert it to a boolean mask
feature_mask = feature_mask[torch.randperm(3004)].to(torch.bool)


from torch.func import jacrev, vmap

z_dim = 200

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40,40)
        self.L3 = nn.Linear(40, 3004)

    def forward(self, z):
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        return self.L3(z)[:, feature_mask]  # this fails after adding the feature mask
        # return self.L3(z)[:, [1, 2, 4, 5, 3]]  # this fails as well

fn = Decoder()

N = 1110
z = torch.randn([N, z_dim])

jacobian_fn = vmap(jacrev(fn))
jacobians = jacobian_fn(z)

The error message is:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[4], line 37
     33 z     = torch.randn([N, z_dim])
     36 jacobian_fn = vmap(jacrev(fn))
---> 37 jacobians = jacobian_fn(z)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
    187 def wrapped(*args, **kwargs):
--> 188     return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:278, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    274     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    275                          args_spec, out_dims, randomness, **kwargs)
    277 # If chunk_size is not specified.
--> 278 return _flat_vmap(
    279     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    280 )

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     41 @functools.wraps(f)
     42 def fn(*args, **kwargs):
     43     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44         return f(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:391, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    389 try:
    390     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391     batched_outputs = func(*batched_inputs, **kwargs)
    392     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    393 finally:

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:500, in jacrev.<locals>.wrapper_fn(*args)
    497 @wraps(func)
    498 def wrapper_fn(*args):
    499     error_if_complex("jacrev", args, is_input=True)
--> 500     vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
    501     if has_aux:
    502         output, vjp_fn, aux = vjp_out

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     41 @functools.wraps(f)
     42 def fn(*args, **kwargs):
     43     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44         return f(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:302, in _vjp_with_argnums(func, argnums, has_aux, *primals)
    300     diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
    301     tree_map_(partial(_create_differentiable, level=level), diff_primals)
--> 302 primals_out = func(*primals)
    304 if has_aux:
    305     if not (isinstance(primals_out, tuple) and len(primals_out) == 2):

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

Cell In[4], line 27, in Decoder.forward(self, z)
     25 z = torch.relu(self.L1(z))
     26 z = torch.relu(self.L2(z))
---> 27 return self.L3(z)[:, feature_mask]

IndexError: too many indices for tensor of dimension 1

I would like to subset inside the function rather than subsetting the output afterwards because I want to save GPU memory. The Jacobian computation is extremely GPU memory-consuming, and subsetting inside lets me load more samples per calculation.
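
For a rough sense of the memory involved, here is a back-of-the-envelope sketch using the sizes from the snippets above and assuming float32 Jacobians:

# Rough per-batch Jacobian sizes with the example dimensions (float32, 4 bytes each)
full_bytes   = 1110 * 3004 * 200 * 4   # full output:   ~2.7 GB
masked_bytes = 1110 * 170 * 200 * 4    # masked output: ~0.15 GB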

Can somebody help me figure out why? Thank you!

Hi @Sijie_Chen,

When you vectorize a function with vmap, it no longer sees the batch dim, which is why you get 'too many indices for tensor of dimension 1': inside the vmapped function only the output dimension exists. So you can either replace [:, feature_mask] with [..., feature_mask], or apply the mask after the vectorized call.
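
For example, the forward of the masked Decoder could look like this (a minimal sketch; only the indexing changes, everything else is from the snippet above):

    def forward(self, z):
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        # Inside vmap, self.L3(z) is a single sample of shape (3004,), so
        # index its last dimension with an ellipsis instead of [:, ...]
        return self.L3(z)[..., feature_mask]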

Also, just a word of advice: when computing jacrev for an nn.Module, you need to 'functionalize' your model, so that the parameters become an explicit input to it.

So replace jacobian_fn = vmap(jacrev(fn)) and jacobians = jacobian_fn(z) with:

params = dict(fn.named_parameters())  # params need to be an input

# For convenience, define the 'functionalized' call as a lambda function
fcall = lambda params, inputs: torch.func.functional_call(fn, params, inputs)

# Then we vmap/jacrev over the functionalized version of our Decoder.
# argnums=1 tells jacrev to differentiate w.r.t. inputs (i.e. z), and
# in_dims=(None, 0) tells vmap to vectorize over the batch dim of the
# inputs only (there is no batch dim for the parameters).
jacobians = vmap(jacrev(fcall, argnums=1), in_dims=(None, 0))(params, z)
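
Putting both pieces together (the [..., feature_mask] indexing and the functionalized call), a minimal sketch of the full computation, reusing the Decoder and feature_mask from the question; the expected shape assumes the 170-element mask above:

import torch
from torch.func import functional_call, jacrev, vmap

fn = Decoder()                          # Decoder whose forward uses [..., feature_mask]
params = dict(fn.named_parameters())

def fcall(params, inputs):
    # Functionalized model: parameters are passed explicitly
    return functional_call(fn, params, inputs)

z = torch.randn(1110, 200)
jacobians = vmap(jacrev(fcall, argnums=1), in_dims=(None, 0))(params, z)
print(jacobians.shape)  # torch.Size([1110, 170, 200]) with the mask above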