I’m trying to vectorize the calculation of per-sample gradients for an LLM using vmap (see the per-sample-gradients tutorial). However, I’m getting an error (see below) suggesting that this fails due to data-dependent control flow. Specifically, the culprit seems to be the if/else handling of attention masks - is it not possible to use vmap for Transformer models that use attention masks? The vectorization works fine for vision Transformers (which have no attention mask).
I’m providing a self-contained minimal reproducer. To test that the code is valid in principle, run it with vectorized=False, which runs through. However, with vectorized=True, it gives this error:
RuntimeError: vmap: It looks like you’re either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don’t support vmap over calling .item() on a Tensor, please try to rewrite what you’re doing with other operations. For (2): If you’re doing some control flow instead, we don’t support that yet, please shout over at Data-dependent control flow exploration · Issue #257 · pytorch/functorch · GitHub. For (3): please file an issue.
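For reference, here is a minimal sketch, unrelated to transformers (the function and tensors are just made up for illustration), of the kind of pattern that triggers this error under vmap, i.e. branching in Python on a tensor’s value:

import torch

def branch_on_value(x):
    # Under vmap, `x.sum() > 0` is a batched tensor; converting it to a Python bool
    # for the `if` is data-dependent control flow and raises the error quoted above.
    if x.sum() > 0:
        return x * 2
    return x

torch.func.vmap(branch_on_value)(torch.randn(4, 3))  # raises the RuntimeError above

My suspicion is that the attention-mask handling inside the model does something similar, i.e. checks the mask’s values in Python, but I haven’t pinned down where.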
I’ve tried to simplify the code, but apologies, it is still a bit long. Note that, as a further simplification, it calculates the per-sample gradient of the sum of the logits rather than the gradient of the loss; this still produces the same error.
from datasets import load_dataset
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
def calc_grad_sample_functional_llm(
model,
X_ids,
X_mask,
y,
weight_name,
max_batch_size,
vectorized=False
):
"""Calculates the gradient of the sum of logits, using the functional interface. LLM version."""
def _calc_logits(params, buffers, X_ids, X_mask, y):
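        # functional_call runs the model with `params` and `buffers` substituted in,
        # so the gradient can be taken w.r.t. the params dict passed in as an argument.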
predictions = torch.func.functional_call(model, (params, buffers), (X_ids, X_mask))
logits = predictions.logits
return logits.sum()
    # Define a function that returns the flattened gradient of the summed logits w.r.t. weight_name.
def _calc_grad(*args):
return torch.func.grad(_calc_logits)(*args)[weight_name].view(-1)
# Vectorize the above function over the sample and target (3rd, 4th, and 5th) args
# of _calc_logits(), but use the same params and buffers (1st and 2nd) args for all
    # samples in the batch.
calc_grad_sample = torch.func.vmap(_calc_grad, in_dims=(None, None, 0, 0, 0))
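    # Restrict params to the single weight of interest; buffers picks up any buffers
    # registered under the same submodule prefix.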
params = {k: v for k, v in model.named_parameters() if k == weight_name}
    buffers = {k: v for k, v in model.named_buffers() if k.startswith(weight_name.removesuffix("weight"))}
grad_sample = []
for X_ids_batch, X_mask_batch, y_batch in zip(
torch.split(X_ids, max_batch_size, dim=0),
torch.split(X_mask, max_batch_size, dim=0),
torch.split(y, max_batch_size, dim=0),
):
with torch.no_grad():
if not vectorized:
# Works! But not vectorized
for X_ids_sample, X_mask_sample, y_sample in zip(X_ids_batch, X_mask_batch, y_batch):
grad_sample_single_sample = _calc_grad(params, buffers, X_ids_sample, X_mask_sample, y_sample)
grad_sample.append(grad_sample_single_sample)
else:
# Gives an error! Vectorized
grad_sample_batch = calc_grad_sample(params, buffers, X_ids_batch, X_mask_batch, y_batch)
        if vectorized:
            del X_ids_batch, X_mask_batch, y_batch
            grad_sample.append(grad_sample_batch)
return torch.vstack(grad_sample)
def load_llm(model_name):
"""Loads an LLM."""
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
return model, tokenizer
def buffer_input(sample, max_seqlen):
"""Fills in missing values in sample, or truncate, as needed."""
input_ids = sample["input_ids"]
attention_mask = sample["attention_mask"]
n_records, n_tokens = input_ids.shape
if n_tokens > max_seqlen:
# Truncate down to max_seqlen
input_ids = input_ids[:, :max_seqlen]
attention_mask = attention_mask[:, :max_seqlen]
n_tokens = max_seqlen
else:
# Fill with zeros up to max_seqlen
n_fill = max_seqlen - n_tokens
filler = torch.zeros(size=(n_records, n_fill), dtype=torch.int64)
input_ids = torch.cat((input_ids, filler), dim=1)
attention_mask = torch.cat((attention_mask, filler), dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}, n_tokens
def prep_loader_batch(batch, tokenizer, seqlen):
"""
Extracts the set of lines in "text", tokenizes the text (converts each string to the
appropriate token ID via the tokenizer), passes the tokens to buffer_input(),
collects results, and returns the full collected set.
"""
in_cache = []
target_cache = []
for line in batch["text"]:
line_tokens = tokenizer(line, return_tensors="pt")
in_buffered, n_tokens = buffer_input(line_tokens, seqlen)
in_cache.append(in_buffered)
targets = in_buffered["input_ids"].clone()
# Anywhere filler is present in inputs, mask output with -100
targets[:, n_tokens:] = -100
target_cache.append(targets)
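    # Each tokenized sample has shape (1, max_seqlen) (the tokenizer adds a batch dim),
    # so stacking along dim 0 yields tensors of shape (n_samples, 1, max_seqlen).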
X_ids = torch.stack(tuple(sample["input_ids"] for sample in in_cache))
X_mask = torch.stack(tuple(sample["attention_mask"] for sample in in_cache))
y = torch.stack(target_cache)
return X_ids, X_mask, y
weight_name = "model.layers.1.mlp.up_proj.weight"
model, tokenizer = load_llm("stas/tiny-random-llama-2")
train_dataset = load_dataset("stas/c4-en-10k", split="train")
batch = train_dataset[:2] # Get just two samples
seqlen = model.config.max_position_embeddings
X_ids, X_mask, y = prep_loader_batch(batch, tokenizer, seqlen)
for vectorized in [False, True]:
print(f"{vectorized=}")
calc_grad_sample_functional_llm(model, X_ids, X_mask, y, weight_name, max_batch_size=10, vectorized=vectorized)
print(" No errors")