Trying to explain the Zephyr generative LLM

Hi,

I'm trying to use Captum's new features to explain Zephyr, in particular LLMAttribution and TextTokenInput, in order to apply FeatureAblation, ShapleyValues and Lime.
I keep getting the following common error message: "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"

Here is the related snippet of code (following Captum’s tutorial):

from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

model_name = 'HuggingFaceH4/zephyr-7b-beta'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",  # dispatch the model efficiently across the available resources
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_template = "Dave lives in Palm Coast, FL and is a lawyer. His personal interests include"
target = 'playing piano'

explainer = FeatureAblation(model)
llm_attr = LLMAttribution(explainer, tokenizer)

inp = TextTokenInput(
    prompt_template,
    tokenizer,
    skip_tokens=[1],  # skip the special token for the start of the text
)

attr_res = llm_attr.attribute(inp, target=target)

It works when device_map='cpu', but it fails with the error above when device_map='auto' (or device_map='cuda', as in the snippet).
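
In case it helps, here is a quick sanity check (a sketch reusing the model and tokenizer above) to confirm that a plain forward pass outside Captum runs on the GPU:

import torch

# Sanity check (sketch): run a plain forward pass outside Captum.
# If this succeeds, the CPU tensor in the error is likely created inside
# the attribution call rather than by the model or tokenizer setup.
inputs = tokenizer(prompt_template, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs)
print(out.logits.shape, out.logits.device)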

Many thanks in advance.

Hi @milanbhan, could you post the full error trace? The error simply means that one tensor was not on cuda. The whole error log may help us identify which tensor it is.

Hello @aobo-y, thank you for your prompt reply :grinning:
To be precise, I have the same problem with other models such as Phi2 or Orca.
Here is the error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 11
      4 inp = TextTemplateInput(
      5     template="{} lives in {}, {} and is a {}. {} personal interests include", 
      6     values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
      7 )
      9 target = "playing golf, hiking, and cooking."
---> 11 attr_res = llm_attr.attribute(inp, target=target)
     13 attr_res.plot_token_attr(show=True)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:361, in LLMAttribution.attribute(self, inp, target, num_trials, gen_args, _inspect_forward, **kwargs)
    358 for _ in range(num_trials):
    359     attr_input = inp.to_tensor().to(self.device)
--> 361     cur_attr = self.attr_method.attribute(
    362         attr_input,
    363         additional_forward_args=(inp, target_tokens, _inspect_forward),
    364         **kwargs,
    365     )
    367     # temp necessary due to FA & Shapley's different return shape of multi-task
    368     # FA will flatten output shape internally (n_output_token, n_itp_features)
    369     # Shapley will keep output shape (batch, n_output_token, n_input_features)
    370     cur_attr = cur_attr.reshape(attr.shape)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:289, in FeatureAblation.attribute(self, inputs, baselines, target, additional_forward_args, feature_mask, perturbations_per_eval, show_progress, **kwargs)
    285     attr_progress.update(0)
    287 # Computes initial evaluation with all features, which is compared
    288 # to each ablated result.
--> 289 initial_eval = self._strict_run_forward(
    290     self.forward_func, inputs, target, additional_forward_args
    291 )
    293 if show_progress:
    294     attr_progress.update()

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:599, in FeatureAblation._strict_run_forward(self, *args, **kwargs)
    593 def _strict_run_forward(self, *args, **kwargs) -> Tensor:
    594     """
    595     A temp wrapper for global _run_forward util to force forward output
    596     type assertion & conversion.
    597     Remove after the strict logic is supported by all attr classes
    598     """
--> 599     forward_output = _run_forward(*args, **kwargs)
    600     if isinstance(forward_output, Tensor):
    601         return forward_output

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
    528 inputs = _format_inputs(inputs)
    529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
    532     *(*inputs, *additional_forward_args)
    533     if additional_forward_args is not None
    534     else inputs
    535 )
    536 return _select_targets(output, target)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:244, in LLMAttribution._forward_func(self, perturbed_tensor, inp, target_tokens, _inspect_forward)
    242 log_prob_list = []
    243 for target_token in target_tokens:
--> 244     output_logits = self.model.forward(
    245         model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
    246     )
    247     new_token_logits = output_logits.logits[:, -1]
    248     log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:1044, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1041 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1043 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1044 outputs = self.model(
   1045     input_ids=input_ids,
   1046     attention_mask=attention_mask,
   1047     position_ids=position_ids,
   1048     past_key_values=past_key_values,
   1049     inputs_embeds=inputs_embeds,
   1050     use_cache=use_cache,
   1051     output_attentions=output_attentions,
   1052     output_hidden_states=output_hidden_states,
   1053     return_dict=return_dict,
   1054 )
   1056 hidden_states = outputs[0]
   1057 logits = self.lm_head(hidden_states)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:929, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    919     layer_outputs = self._gradient_checkpointing_func(
    920         decoder_layer.__call__,
    921         hidden_states,
   (...)
    926         use_cache,
    927     )
    928 else:
--> 929     layer_outputs = decoder_layer(
    930         hidden_states,
    931         attention_mask=attention_mask,
    932         position_ids=position_ids,
    933         past_key_value=past_key_values,
    934         output_attentions=output_attentions,
    935         use_cache=use_cache,
    936     )
    938 hidden_states = layer_outputs[0]
    940 if use_cache:

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:654, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    651 hidden_states = self.input_layernorm(hidden_states)
    653 # Self Attention
--> 654 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    655     hidden_states=hidden_states,
    656     attention_mask=attention_mask,
    657     position_ids=position_ids,
    658     past_key_value=past_key_value,
    659     output_attentions=output_attentions,
    660     use_cache=use_cache,
    661 )
    662 hidden_states = residual + hidden_states
    664 # Fully Connected

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:297, in MistralAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    292     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
    293         raise ValueError(
    294             f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    295         )
--> 297     attn_weights = attn_weights + attention_mask
    299 # upcast attention to fp32
    300 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Thank you very much

Milan

Out of curiosity, have you tried Llama, and did you encounter this issue there? I suspect the error is a bug in Captum's captum/attr/_core/llm_attr.py, where the attention_mask is not moved to the same device as the model. Hugging Face may handle this differently for different models, as we did not see this issue with Llama.
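
For context, the trace above shows where the mismatch comes from: the input tensor is moved to the model's device (llm_attr.py, line 359: attr_input = inp.to_tensor().to(self.device)), while the attention mask is built with torch.tensor(...) and therefore defaults to CPU (lines 244-246). A minimal standalone illustration of that behaviour:

import torch

# torch.tensor defaults to CPU unless a device is given, so combining it with a
# CUDA tensor raises the same "Expected all tensors to be on the same device" error.
if torch.cuda.is_available():
    model_inp = torch.ones(1, 8, device="cuda")                 # like inp.to_tensor().to(self.device)
    attention_mask = torch.tensor([[1] * model_inp.shape[1]])   # created on CPU
    print(model_inp.device, attention_mask.device)              # cuda:0 cpu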

To fix the issue for now, you may try moving the attention_mask onto the model's device with attention_mask.to(self.device) in captum/attr/_core/llm_attr.py.
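
Based on the code shown in the trace (llm_attr.py, lines 244-246), the locally patched call could look roughly like this (a sketch against that version of Captum, not an official fix):

# Inside LLMAttribution._forward_func in captum/attr/_core/llm_attr.py (sketch):
# move the attention mask to the same device as the model input instead of
# leaving it on CPU.
output_logits = self.model.forward(
    model_inp,
    attention_mask=torch.tensor([[1] * model_inp.shape[1]]).to(self.device),
)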

Hello,
Thank you, it works now :slight_smile: