Vectorizing calculation of per-sample gradient for LLMs

I’m trying to vectorize the calculation of per-sample gradients for an LLM using vmap (see the tutorial). However, I’m getting an error (see below) that seems to suggest that this is failing due to data-dependent control flow. Specifically, the issue seems to be due to the if/else treatment of attention masks - is it not possible to use vmap for Transformer models that use attention masks? The vectorization seems to work fine for vision Transformers (with no attention mask).
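
For reference, the basic per-sample-gradient pattern from the tutorial works for me as long as no attention mask is involved; for example, this toy version (a simplified sketch, not my actual vision-Transformer code) runs without problems:

import torch

# A toy model with no attention mask, to show the pattern that does work
model = torch.nn.Linear(4, 2)
X = torch.randn(8, 4)  # a batch of 8 samples

def compute_output_sum(params, buffers, x):
    out = torch.func.functional_call(model, (params, buffers), (x.unsqueeze(0),))
    return out.sum()

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

# grad over the parameters, vmap over the sample dimension of X
per_sample_grads = torch.func.vmap(
    torch.func.grad(compute_output_sum), in_dims=(None, None, 0)
)(params, buffers, X)
print(per_sample_grads["weight"].shape)  # torch.Size([8, 2, 4])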

I’m providing a self-contained minimal reproducer below. To check that the code is valid in principle, run it with vectorized=False, which completes without errors. With vectorized=True, however, it gives this error:

RuntimeError: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257. For (3): please file an issue.

I’ve tried to simplify the code, but apologies, it is still a bit long. Note that, to keep it simple, it calculates the per-sample gradient of the sum of the logits instead of the gradient of the loss; this produces the same error.


from datasets import load_dataset
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

def calc_grad_sample_functional_llm(
    model,
    X_ids,
    X_mask,
    y,
    weight_name,
    max_batch_size,
    vectorized=False
):
    """Calculates the gradient of the sum of logits, using the functional interface. LLM version."""
    def _calc_logits(params, buffers, X_ids, X_mask, y):
        predictions = torch.func.functional_call(model, (params, buffers), (X_ids, X_mask))
        logits = predictions.logits
        return logits.sum()
        
    # Define a function that returns the gradient of the logits. 
    def _calc_grad(*args):
        return torch.func.grad(_calc_logits)(*args)[weight_name].view(-1)

    # Vectorize the above function over the sample and target (3rd, 4th, and 5th) args
    # of _calc_logits(), but use the same params and buffers (1st and 2nd) args for all
    # batches.
    calc_grad_sample = torch.func.vmap(_calc_grad, in_dims=(None, None, 0, 0, 0))

    params = {k: v for k, v in model.named_parameters() if k == weight_name}
    # str.rstrip strips a set of characters, not a suffix, so build the
    # submodule prefix explicitly to select this layer's buffers.
    module_prefix = weight_name.rsplit(".", 1)[0] + "."
    buffers = {k: v for k, v in model.named_buffers() if k.startswith(module_prefix)}

    grad_sample = []
    for X_ids_batch, X_mask_batch, y_batch in zip(
        torch.split(X_ids, max_batch_size, dim=0),
        torch.split(X_mask, max_batch_size, dim=0),
        torch.split(y, max_batch_size, dim=0),
    ):
        with torch.no_grad():
            if not vectorized:
                # Works! But not vectorized
                for X_ids_sample, X_mask_sample, y_sample in zip(X_ids_batch, X_mask_batch, y_batch):
                    grad_sample_single_sample = _calc_grad(params, buffers, X_ids_sample, X_mask_sample, y_sample)
                    grad_sample.append(grad_sample_single_sample)
            else:
                # Gives an error! Vectorized
                grad_sample_batch = calc_grad_sample(params, buffers, X_ids_batch, X_mask_batch, y_batch)
        
        if vectorized:
            del X_ids_batch, X_mask_batch, y_batch
            grad_sample.append(grad_sample_batch)

    return torch.vstack(grad_sample)

def load_llm(model_name):
    """Loads an LLM."""
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return model, tokenizer


def buffer_input(sample, max_seqlen):
    """Fills in missing values in sample, or truncate, as needed."""
    input_ids = sample["input_ids"]
    attention_mask = sample["attention_mask"]
    n_records, n_tokens = input_ids.shape
    if n_tokens > max_seqlen:
        # Truncate down to max_seqlen
        input_ids = input_ids[:, :max_seqlen]
        attention_mask = attention_mask[:, :max_seqlen]
        n_tokens = max_seqlen
    else:
        # Fill with zeros up to max_seqlen
        n_fill = max_seqlen - n_tokens
        filler = torch.zeros(size=(n_records, n_fill), dtype=torch.int64)
        input_ids = torch.cat((input_ids, filler), dim=1)
        attention_mask = torch.cat((attention_mask, filler), dim=1)

    return {"input_ids": input_ids, "attention_mask": attention_mask}, n_tokens


def prep_loader_batch(batch, tokenizer, seqlen):
    """
    Extracts the set of lines in "text", tokenizes the text (converts each string to the
    appropriate token ID via the tokenizer), passes the tokens to buffer_input(),
    collects results, and returns the full collected set.
    """
    in_cache = []
    target_cache = []
    for line in batch["text"]:
        line_tokens = tokenizer(line, return_tensors="pt")
        in_buffered, n_tokens = buffer_input(line_tokens, seqlen)
        in_cache.append(in_buffered)

        targets = in_buffered["input_ids"].clone()
        # Anywhere filler is present in inputs, mask output with -100
        targets[:, n_tokens:] = -100
        target_cache.append(targets)

    # Note: each element of in_cache keeps a leading batch dimension of 1, so these
    # stacks have shape (n_samples, 1, seqlen).
    X_ids = torch.stack(tuple(sample["input_ids"] for sample in in_cache))
    X_mask = torch.stack(tuple(sample["attention_mask"] for sample in in_cache))
    y = torch.stack(target_cache)

    return X_ids, X_mask, y

weight_name = "model.layers.1.mlp.up_proj.weight"
model, tokenizer = load_llm("stas/tiny-random-llama-2")
train_dataset = load_dataset("stas/c4-en-10k", split="train")
batch = train_dataset[:2] # Get just two samples

seqlen = model.config.max_position_embeddings
X_ids, X_mask, y = prep_loader_batch(batch, tokenizer, seqlen)

for vectorized in [False, True]:
    print(f"{vectorized=}")
    calc_grad_sample_functional_llm(model, X_ids, X_mask, y, weight_name, max_batch_size=10, vectorized=vectorized)
    print("  No errors")

Hi @MootWoot,

Can you share the full stacktrace of the error? Also, why are you computing the gradients within a torch.no_grad() context manager?

Also, can you share the code for the LLM?

Thanks, @AlphaBetaGamma96 - please see the full stacktrace below.

The reason I’m using no_grad() is that I noticed the code was using much more memory than expected. On a whim, I tried it, and it works. By “works” I mean that:

1. it runs - often a good indication in itself of whether it’s OK to use no_grad(), in my experience,
2. it produces the correct result - as validated separately, in unit tests, and
3. it reduces the memory requirements.

My hand-wavy, uneducated guess is that when using the functional interface with vmap, the gradients are still being tracked in some other way, so it’s not necessary to track them in the usual way - but any insight on that would be appreciated.
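
A quick sanity check along those lines: if I understand the torch.func.grad docs correctly, an outer no_grad() does not disable the grad transform itself, only the regular autograd recording around it. A toy sketch:

import torch

def f(w, x):
    return (w * x).sum()

w, x = torch.randn(3), torch.randn(3)

with torch.no_grad():
    # The grad transform still computes gradients here; only the usual
    # autograd graph outside the transform is suppressed.
    g = torch.func.grad(f)(w, x)

print(torch.allclose(g, x))  # True: d/dw of (w * x).sum() is x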

The code for creating the LLM is here - it’s essentially a tiny LLaMa model with random weights - which is quite convenient for testing.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 118
    116 for vectorized in [False, True]:
    117     print(f"{vectorized=}")
--> 118     calc_grad_sample_functional_llm(model, X_ids, X_mask, y, weight_name, max_batch_size=10, vectorized=vectorized)
    119     print("  No errors")

Cell In[4], line 48, in calc_grad_sample_functional_llm(model, X_ids, X_mask, y, weight_name, max_batch_size, vectorized)
     45             grad_sample.append(grad_sample_single_sample)
     46     else:
     47         # Give an error! Vectorized
---> 48         grad_sample_batch = calc_grad_sample(params, buffers, X_ids_batch, X_mask_batch, y_batch)
     50 if vectorized:
     51     del X_batch, y_batch

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/apis.py:188, in vmap.<locals>.wrapped(*args, **kwargs)
    187 def wrapped(*args, **kwargs):
--> 188     return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/vmap.py:278, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    274     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    275                          args_spec, out_dims, randomness, **kwargs)
    277 # If chunk_size is not specified.
--> 278 return _flat_vmap(
    279     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    280 )

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     41 @functools.wraps(f)
     42 def fn(*args, **kwargs):
     43     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44         return f(*args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/vmap.py:391, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    389 try:
    390     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391     batched_outputs = func(*batched_inputs, **kwargs)
    392     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    393 finally:

Cell In[4], line 24, in calc_grad_sample_functional_llm.<locals>._calc_grad(*args)
     23 def _calc_grad(*args):
---> 24     return torch.func.grad(_calc_logits)(*args)[weight_name].view(-1)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/apis.py:363, in grad.<locals>.wrapper(*args, **kwargs)
    361 @functools.wraps(func)
    362 def wrapper(*args, **kwargs):
--> 363     return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:1295, in grad_impl(func, argnums, has_aux, args, kwargs)
   1293 def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
   1294     func = lazy_dynamo_disable(func)
-> 1295     results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
   1296     if has_aux:
   1297         grad, (_, aux) = results

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/vmap.py:44, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     41 @functools.wraps(f)
     42 def fn(*args, **kwargs):
     43     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44         return f(*args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py:1256, in grad_and_value.<locals>.wrapper(*args, **kwargs)
   1253 diff_args = _slice_argnums(args, argnums, as_tuple=False)
   1254 tree_map_(partial(_create_differentiable, level=level), diff_args)
-> 1256 output = func(*args, **kwargs)
   1257 if has_aux:
   1258     if not (isinstance(output, tuple) and len(output) == 2):

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:489, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    487     dynamo_config_ctx.__enter__()
    488 try:
--> 489     return fn(*args, **kwargs)
    490 finally:
    491     set_eval_frame(prior)

Cell In[4], line 18, in calc_grad_sample_functional_llm.<locals>._calc_logits(params, buffers, X_ids, X_mask, y)
     17 def _calc_logits(params, buffers, X_ids, X_mask, y):
---> 18     predictions = torch.func.functional_call(model, (params, buffers), (X_ids, X_mask))
     19     logits = predictions.logits
     20     return logits.sum()

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/_functorch/functional_call.py:143, in functional_call(module, parameter_and_buffer_dicts, args, kwargs, tie_weights, strict)
    137 else:
    138     raise ValueError(
    139         f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
    140         f"but got {type(parameter_and_buffer_dicts)}"
    141     )
--> 143 return nn.utils.stateless._functional_call(
    144     module,
    145     parameters_and_buffers,
    146     args,
    147     kwargs,
    148     tie_weights=tie_weights,
    149     strict=strict,
    150 )

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/nn/utils/stateless.py:263, in _functional_call(module, parameters_and_buffers, args, kwargs, tie_weights, strict)
    259     args = (args,)
    260 with _reparametrize_module(
    261     module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    262 ):
--> 263     return module(*args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:1208, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1205 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1207 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1208 outputs = self.model(
   1209     input_ids=input_ids,
   1210     attention_mask=attention_mask,
   1211     position_ids=position_ids,
   1212     past_key_values=past_key_values,
   1213     inputs_embeds=inputs_embeds,
   1214     use_cache=use_cache,
   1215     output_attentions=output_attentions,
   1216     output_hidden_states=output_hidden_states,
   1217     return_dict=return_dict,
   1218     cache_position=cache_position,
   1219 )
   1221 hidden_states = outputs[0]
   1222 if self.config.pretraining_tp > 1:

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:992, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    989 if position_ids is None:
    990     position_ids = cache_position.unsqueeze(0)
--> 992 causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
    994 # embed positions
    995 hidden_states = inputs_embeds

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:1076, in LlamaModel._update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens)
   1071     return None
   1073 if self.config._attn_implementation == "sdpa":
   1074     # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
   1075     # in order to dispatch on Flash Attention 2.
-> 1076     if AttentionMaskConverter._ignore_causal_mask_sdpa(
   1077         attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
   1078     ):
   1079         return None
   1081 dtype, device = input_tensor.dtype, input_tensor.device

File ~/.pyenv/versions/prunes/lib/python3.9/site-packages/transformers/modeling_attn_mask_utils.py:282, in AttentionMaskConverter._ignore_causal_mask_sdpa(attention_mask, inputs_embeds, past_key_values_length, sliding_window)
    278     if tuple(attention_mask.shape) != expected_shape:
    279         raise ValueError(
    280             f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
    281         )
--> 282 elif not is_tracing and torch.all(attention_mask == 1):
    283     if query_length == 1 or key_value_length == query_length:
    284         # For query_length == 1, causal attention and bi-directional attention are the same.
    285         ignore_causal_mask = True

RuntimeError: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue.

Are you using sdpa under the hood within your LLM? The stacktrace shows that it fails on the line

elif not is_tracing and torch.all(attention_mask == 1):

most likely because the result of torch.all(attention_mask == 1) is used in Python control flow, which torch.func.vmap doesn't support - that's case (2) in the error message. And this code path is only taken when the attention implementation is sdpa:

if self.config._attn_implementation == "sdpa":
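
You can reproduce the same failure mode without transformers at all (a minimal sketch):

import torch

def f(mask):
    # Branching on a tensor's truth value is data-dependent control flow,
    # which vmap cannot handle.
    if torch.all(mask == 1):
        return mask.sum()
    return mask.sum() * 2

masks = torch.ones(4, 8)
torch.func.vmap(f)(masks)  # raises the same "data-dependent control flow" RuntimeError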

Could you try running the LLM without using sdpa?
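
For example, something like this should select the eager attention implementation instead (an untested sketch; it assumes your transformers version accepts the attn_implementation argument):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "stas/tiny-random-llama-2",
    attn_implementation="eager",  # bypasses the sdpa-specific torch.all(...) mask check
)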