Hello @aobo-y, thank you for your prompt reply.
To be clear, I get the same problem with other models such as Phi2 or Orca.
Here is the error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 11
4 inp = TextTemplateInput(
5 template="{} lives in {}, {} and is a {}. {} personal interests include",
6 values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
7 )
9 target = "playing golf, hiking, and cooking."
---> 11 attr_res = llm_attr.attribute(inp, target=target)
13 attr_res.plot_token_attr(show=True)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:361, in LLMAttribution.attribute(self, inp, target, num_trials, gen_args, _inspect_forward, **kwargs)
358 for _ in range(num_trials):
359 attr_input = inp.to_tensor().to(self.device)
--> 361 cur_attr = self.attr_method.attribute(
362 attr_input,
363 additional_forward_args=(inp, target_tokens, _inspect_forward),
364 **kwargs,
365 )
367 # temp necessary due to FA & Shapley's different return shape of multi-task
368 # FA will flatten output shape internally (n_output_token, n_itp_features)
369 # Shapley will keep output shape (batch, n_output_token, n_input_features)
370 cur_attr = cur_attr.reshape(attr.shape)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:289, in FeatureAblation.attribute(self, inputs, baselines, target, additional_forward_args, feature_mask, perturbations_per_eval, show_progress, **kwargs)
285 attr_progress.update(0)
287 # Computes initial evaluation with all features, which is compared
288 # to each ablated result.
--> 289 initial_eval = self._strict_run_forward(
290 self.forward_func, inputs, target, additional_forward_args
291 )
293 if show_progress:
294 attr_progress.update()
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py:599, in FeatureAblation._strict_run_forward(self, *args, **kwargs)
593 def _strict_run_forward(self, *args, **kwargs) -> Tensor:
594 """
595 A temp wrapper for global _run_forward util to force forward output
596 type assertion & conversion.
597 Remove after the strict logic is supported by all attr classes
598 """
--> 599 forward_output = _run_forward(*args, **kwargs)
600 if isinstance(forward_output, Tensor):
601 return forward_output
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
528 inputs = _format_inputs(inputs)
529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
532 *(*inputs, *additional_forward_args)
533 if additional_forward_args is not None
534 else inputs
535 )
536 return _select_targets(output, target)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py:244, in LLMAttribution._forward_func(self, perturbed_tensor, inp, target_tokens, _inspect_forward)
242 log_prob_list = []
243 for target_token in target_tokens:
--> 244 output_logits = self.model.forward(
245 model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
246 )
247 new_token_logits = output_logits.logits[:, -1]
248 log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:1044, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1041 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1043 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1044 outputs = self.model(
1045 input_ids=input_ids,
1046 attention_mask=attention_mask,
1047 position_ids=position_ids,
1048 past_key_values=past_key_values,
1049 inputs_embeds=inputs_embeds,
1050 use_cache=use_cache,
1051 output_attentions=output_attentions,
1052 output_hidden_states=output_hidden_states,
1053 return_dict=return_dict,
1054 )
1056 hidden_states = outputs[0]
1057 logits = self.lm_head(hidden_states)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:929, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
919 layer_outputs = self._gradient_checkpointing_func(
920 decoder_layer.__call__,
921 hidden_states,
(...)
926 use_cache,
927 )
928 else:
--> 929 layer_outputs = decoder_layer(
930 hidden_states,
931 attention_mask=attention_mask,
932 position_ids=position_ids,
933 past_key_value=past_key_values,
934 output_attentions=output_attentions,
935 use_cache=use_cache,
936 )
938 hidden_states = layer_outputs[0]
940 if use_cache:
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:654, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
651 hidden_states = self.input_layernorm(hidden_states)
653 # Self Attention
--> 654 hidden_states, self_attn_weights, present_key_value = self.self_attn(
655 hidden_states=hidden_states,
656 attention_mask=attention_mask,
657 position_ids=position_ids,
658 past_key_value=past_key_value,
659 output_attentions=output_attentions,
660 use_cache=use_cache,
661 )
662 hidden_states = residual + hidden_states
664 # Fully Connected
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.conda/envs/pytorch_env/lib/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:297, in MistralAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
292 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
293 raise ValueError(
294 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
295 )
--> 297 attn_weights = attn_weights + attention_mask
299 # upcast attention to fp32
300 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
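Looking at the last frames, my guess (not verified) is that the attention mask built inside LLMAttribution._forward_func via torch.tensor([[1] * model_inp.shape[1]]) stays on CPU while the model weights sit on cuda:0. Below is a minimal sketch of the workaround I am considering, assuming that is indeed the cause: keeping the whole pipeline on CPU so no tensor ends up on a different device. The model name and loading arguments are only illustrative of my setup, not a fix for the underlying issue.

# Minimal sketch of a possible workaround, assuming the device mismatch comes
# from the CPU attention mask built inside LLMAttribution._forward_func.
# Model name and loading arguments below are illustrative of my setup.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # same error with Phi2 / Orca

# Keep everything on CPU so all tensors live on the same device.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
model.to("cpu")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

fa = FeatureAblation(model)
llm_attr = LLMAttribution(fa, tokenizer)

inp = TextTemplateInput(
    template="{} lives in {}, {} and is a {}. {} personal interests include",
    values=["Dave", "Palm Coast", "FL", "lawyer", "His"],
)
target = "playing golf, hiking, and cooking."

attr_res = llm_attr.attribute(inp, target=target)
attr_res.plot_token_attr(show=True)

If there is a recommended way to keep the attention mask on the model's device instead (e.g. building it on model_inp.device), I would be happy to test it.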
Thank you very much,
Milan