Trying to explain a Zephyr generative LLM

Hello @aobo-y, thank you for your prompt reply :grinning:
Note that I run into the same problem with other models such as Phi2 or Orca.
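
For completeness, here is roughly the cell I am running. The model/tokenizer loading lines are a sketch from memory (the checkpoint name and the `.to("cuda")` are my best recollection, so treat them as assumptions); the `TextTemplateInput` part matches the traceback below:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

# Sketch from memory: moving the checkpoint onto the GPU is what puts
# the model weights on cuda:0
model_name = "HuggingFaceH4/zephyr-7b-beta"  # same error with Phi2 / Orca
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

fa = FeatureAblation(model)
llm_attr = LLMAttribution(fa, tokenizer)

inp = TextTemplateInput(
    template="{} lives in {}, {} and is a {}. {} personal interests include",
    values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
)
target = "playing golf, hiking, and cooking."

attr_res = llm_attr.attribute(inp, target=target)  # <- this is the line that fails
attr_res.plot_token_attr(show=True)
```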
Here is the error message:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 11
      4 inp = TextTemplateInput(
      5     template="{} lives in {}, {} and is a {}. {} personal interests include", 
      6     values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
      7 )
      9 target = "playing golf, hiking, and cooking."
---> 11 attr_res = llm_attr.attribute(inp, target=target)
     13 attr_res.plot_token_attr(show=True)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:361, in LLMAttribution.attribute(self, inp, target, num_trials, gen_args, _inspect_forward, **kwargs)
    358 for _ in range(num_trials):
    359     attr_input = inp.to_tensor().to(self.device)
--> 361     cur_attr = self.attr_method.attribute(
    362         attr_input,
    363         additional_forward_args=(inp, target_tokens, _inspect_forward),
    364         **kwargs,
    365     )
    367     # temp necessary due to FA & Shapley's different return shape of multi-task
    368     # FA will flatten output shape internally (n_output_token, n_itp_features)
    369     # Shapley will keep output shape (batch, n_output_token, n_input_features)
    370     cur_attr = cur_attr.reshape(attr.shape)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:289, in FeatureAblation.attribute(self, inputs, baselines, target, additional_forward_args, feature_mask, perturbations_per_eval, show_progress, **kwargs)
    285     attr_progress.update(0)
    287 # Computes initial evaluation with all features, which is compared
    288 # to each ablated result.
--> 289 initial_eval = self._strict_run_forward(
    290     self.forward_func, inputs, target, additional_forward_args
    291 )
    293 if show_progress:
    294     attr_progress.update()

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:599, in FeatureAblation._strict_run_forward(self, *args, **kwargs)
    593 def _strict_run_forward(self, *args, **kwargs) -> Tensor:
    594     """
    595     A temp wrapper for global _run_forward util to force forward output
    596     type assertion & conversion.
    597     Remove after the strict logic is supported by all attr classes
    598     """
--> 599     forward_output = _run_forward(*args, **kwargs)
    600     if isinstance(forward_output, Tensor):
    601         return forward_output

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
    528 inputs = _format_inputs(inputs)
    529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
    532     *(*inputs, *additional_forward_args)
    533     if additional_forward_args is not None
    534     else inputs
    535 )
    536 return _select_targets(output, target)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:244, in LLMAttribution._forward_func(self, perturbed_tensor, inp, target_tokens, _inspect_forward)
    242 log_prob_list = []
    243 for target_token in target_tokens:
--> 244     output_logits = self.model.forward(
    245         model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
    246     )
    247     new_token_logits = output_logits.logits[:, -1]
    248     log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:1044, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1041 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1043 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1044 outputs = self.model(
   1045     input_ids=input_ids,
   1046     attention_mask=attention_mask,
   1047     position_ids=position_ids,
   1048     past_key_values=past_key_values,
   1049     inputs_embeds=inputs_embeds,
   1050     use_cache=use_cache,
   1051     output_attentions=output_attentions,
   1052     output_hidden_states=output_hidden_states,
   1053     return_dict=return_dict,
   1054 )
   1056 hidden_states = outputs[0]
   1057 logits = self.lm_head(hidden_states)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:929, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    919     layer_outputs = self._gradient_checkpointing_func(
    920         decoder_layer.__call__,
    921         hidden_states,
   (...)
    926         use_cache,
    927     )
    928 else:
--> 929     layer_outputs = decoder_layer(
    930         hidden_states,
    931         attention_mask=attention_mask,
    932         position_ids=position_ids,
    933         past_key_value=past_key_values,
    934         output_attentions=output_attentions,
    935         use_cache=use_cache,
    936     )
    938 hidden_states = layer_outputs[0]
    940 if use_cache:

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:654, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    651 hidden_states = self.input_layernorm(hidden_states)
    653 # Self Attention
--> 654 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    655     hidden_states=hidden_states,
    656     attention_mask=attention_mask,
    657     position_ids=position_ids,
    658     past_key_value=past_key_value,
    659     output_attentions=output_attentions,
    660     use_cache=use_cache,
    661 )
    662 hidden_states = residual + hidden_states
    664 # Fully Connected

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:297, in MistralAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    292     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
    293         raise ValueError(
    294             f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    295         )
--> 297     attn_weights = attn_weights + attention_mask
    299 # upcast attention to fp32
    300 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
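
If I read the last Captum frame correctly, the attention mask built in `captum/attr/_core/llm_attr.py` (line 244 in my install) is created on the CPU with a plain `torch.tensor(...)`, while `model_inp` lives on cuda:0, which would explain the device mismatch. As a quick local experiment I changed that call so the mask follows the input's device; I am not sure this is the proper fix, but sketching it here in case it helps:

```python
# Local experiment inside LLMAttribution._forward_func
# (captum/attr/_core/llm_attr.py, line 244 in my install; not an official fix).
# The original call builds the mask on the CPU by default:
#     attention_mask=torch.tensor([[1] * model_inp.shape[1]])
# Allocating it on the same device as the input ids avoids the mismatch:
output_logits = self.model.forward(
    model_inp,
    attention_mask=torch.tensor(
        [[1] * model_inp.shape[1]], device=model_inp.device
    ),
)
```

I would also expect keeping the model on the CPU entirely (dropping the `.to("cuda")`) to avoid the mismatch, although that is very slow for a 7B model, which suggests the issue is the device handling rather than anything model-specific.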

Thank you very much

Milan