I want to use functorch to calculate the Jacobian matrix of a torch module for some inputs. I am only interested in the Jacobian of a subset of the output features, so I index into the module's output. However, functorch then raises an error.
Here is a working example without output masking:
# test without masks
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

device = torch.device('cuda:0')
z_dim = 200

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40, 40)
        self.L3 = nn.Linear(40, 40)
        self.L4 = nn.Linear(40, 40)
        self.Ln = nn.Linear(40, 3004)

    def forward(self, z):
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        z = torch.relu(self.L3(z))
        z = torch.relu(self.L4(z))
        return self.Ln(z)

N = 1110
z = torch.randn([N, z_dim]).to(device)
fn = Decoder().to(device)

# per-sample Jacobian of the decoder, vectorized over the batch
jacobian_fn = vmap(jacrev(fn))
jacobians = jacobian_fn(z)
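This runs without issue; given the shapes above, the result is one Jacobian per sample:

print(jacobians.shape)  # torch.Size([1110, 3004, 200]), i.e. (N, out_features, z_dim)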
Here is the subset (masked) version, which fails:
# test with masks
import torch
import torch.nn as nn

# Create a tensor of 3004 zeros
feature_mask = torch.zeros(3004, dtype=torch.long)
# Randomly select 170 indices to be 1
ones_indices = torch.randperm(3004)[:170]
# Set those indices to 1
feature_mask[ones_indices] = 1
# Shuffle the tensor to randomize the positions of 1s, then convert to a boolean mask
feature_mask = feature_mask[torch.randperm(3004)].to(torch.bool)
from torch.func import jacrev, vmap

z_dim = 200

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(z_dim, 40)
        self.L2 = nn.Linear(40, 40)
        self.L3 = nn.Linear(40, 3004)

    def forward(self, z):
        z = torch.relu(self.L1(z))
        z = torch.relu(self.L2(z))
        return self.L3(z)[:, feature_mask]  # it fails after adding the feature mask
        # return self.L3(z)[:, [1, 2, 4, 5, 3]]  # this fails as well

fn = Decoder()
N = 1110
z = torch.randn([N, z_dim])
jacobian_fn = vmap(jacrev(fn))
jacobians = jacobian_fn(z)
The error message is:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[4], line 37
33 z = torch.randn([N, z_dim])
36 jacobian_fn = vmap(jacrev(fn))
---> 37 jacobians = jacobian_fn(z)
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
187 def wrapped(*args, **kwargs):
--> 188 return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:278, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
274 return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
275 args_spec, out_dims, randomness, **kwargs)
277 # If chunk_size is not specified.
--> 278 return _flat_vmap(
279 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
280 )
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
41 @functools.wraps(f)
42 def fn(*args, **kwargs):
43 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44 return f(*args, **kwargs)
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:391, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
389 try:
390 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391 batched_outputs = func(*batched_inputs, **kwargs)
392 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
393 finally:
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:500, in jacrev.<locals>.wrapper_fn(*args)
497 @wraps(func)
498 def wrapper_fn(*args):
499 error_if_complex("jacrev", args, is_input=True)
--> 500 vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
501 if has_aux:
502 output, vjp_fn, aux = vjp_out
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
41 @functools.wraps(f)
42 def fn(*args, **kwargs):
43 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44 return f(*args, **kwargs)
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:302, in _vjp_with_argnums(func, argnums, has_aux, *primals)
300 diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
301 tree_map_(partial(_create_differentiable, level=level), diff_primals)
--> 302 primals_out = func(*primals)
304 if has_aux:
305 if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/micromamba/envs/torch2/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
Cell In[4], line 27, in Decoder.forward(self, z)
25 z = torch.relu(self.L1(z))
26 z = torch.relu(self.L2(z))
---> 27 return self.L3(z)[:, feature_mask]
IndexError: too many indices for tensor of dimension 1
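For what it is worth, a plain batched forward pass through the masked decoder works; only the vmap(jacrev(fn)) call fails:

out = fn(z)        # works: z is 2-D here, so [:, feature_mask] indexes dimension 1
print(out.shape)   # torch.Size([1110, 170])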
I want to apply the subset inside the forward function rather than slicing the Jacobian afterwards, because the Jacobian computation is extremely GPU-memory-hungry and subsetting early would let me fit more samples into a single call.
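For comparison, this is the workaround I am trying to avoid (fn_unmasked here is a placeholder for a Decoder without the mask, like the first example): computing the full Jacobian and slicing it afterwards materializes the whole (N, 3004, z_dim) tensor before most of it is thrown away.

# workaround I want to avoid (fn_unmasked = an unmasked Decoder as in the first example)
jacobians_full = vmap(jacrev(fn_unmasked))(z)        # shape (N, 3004, z_dim) -- large
jacobians_sub = jacobians_full[:, feature_mask, :]   # shape (N, 170, z_dim) is all I need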
Can somebody help me figure out why? Thank you!