Why do captum's perturbation and IG treat input & target differently?

I’ve been successful using captum’s LayerIntegratedGradients class,
but none of my attempts to use the same sorts of inputs and targets with
LLMAttribution seem to work.

I’m working with a BertForMultipleChoice model, and each input row is
the same (repeated) prompt followed by one of the choices:

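    # tst['input_ids'] has shape (1, num_choices, seq_len)
    # SEP_IDX is the id of the [SEP] token (e.g. tokenizer.sep_token_id)
    # prefix controls how much of the end of the prompt is shown
    for i, c in enumerate(tst['input_ids'][0]):
        indices = c.detach().tolist()
        sepIdx = indices.index(SEP_IDX)  # first [SEP]: end of the prompt
        preTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx-prefix:sepIdx-1])
        choiceTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx+1:])
        print(f"{i} {sepIdx} {' '.join(preTokens):>55}\t[SEP] {' '.join(choiceTokens)}")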

	0 120           behalf of fake charities . as webster sees it	[SEP] recognizing the guidelines commentary is authoritative [SEP]
	1 109   when he sol ##ici ##ted personal information from the	[SEP] holding that a sentencing guide ##line pre ##va ##ils over its commentary if the two are inconsistent [SEP]
	2 98                                       l ( b ) ( 9 ) ( a	[SEP] holding that sentencing guidelines commentary must be given controlling weight unless it violate ##s the constitution or a federal statute or is plainly inconsistent with the guidelines itself [SEP]
	3 99                                       ( b ) ( 9 ) ( a )	[SEP] holding that commentary is not authoritative if it is inconsistent with or a plainly er ##rone ##ous reading of the guide ##line it interpret ##s or explains [SEP]
	4 119           on behalf of fake charities . as webster sees	[SEP] holding that guidelines commentary is generally authoritative [SEP]
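The listing above contains WordPiece continuation pieces such as `sol ##ici ##ted`. Turning tokens back into text with `' '.join(...)` keeps the `##` markers, so the string no longer matches the original wording and will not re-tokenize into the same ids. The hypothetical `merge_wordpieces` below is a minimal stdlib sketch of the ` ##` removal that, as far as I know, the slow BERT tokenizer's `convert_tokens_to_string` performs; the fast tokenizer's decoder may also tidy punctuation spacing, so treat this as an approximation.

```python
def merge_wordpieces(tokens):
    """Join WordPiece tokens into text, gluing '##' continuation pieces
    onto the preceding token (sketch of BERT's ' ##' removal)."""
    return " ".join(tokens).replace(" ##", "").strip()


pieces = ["when", "he", "sol", "##ici", "##ted", "personal", "information"]
print(" ".join(pieces))          # when he sol ##ici ##ted personal information
print(merge_wordpieces(pieces))  # when he solicited personal information
```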

I’m using LayerIntegratedGradients with a test example and a scalar target giving the index of the correct choice, like this:

    from captum.attr import LayerIntegratedGradients

    tstEGTuple = (tst['input_ids'],
                  tst['attention_mask'],
                  tst['token_type_ids'])
    targetIdx = 3  # index of the correct choice for this particular test example

    lig = LayerIntegratedGradients(custForwardModel, model.bert.embeddings)
    attributions_ig = lig.attribute(tstEGTuple, n_steps=5, target=targetIdx)

and that works, e.g., allowing calculations like summarize_attributions(attributions_ig), viz.VisualizationDataRecord(), etc.
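For context on how the integer target is used here: assuming `custForwardModel` returns the multiple-choice logits with shape `(batch_size, num_choices)`, Captum's integer `target` picks one column, i.e. the score of one choice, for every example in the batch. A stdlib sketch of that selection with made-up logits (`select_target` is a hypothetical name, not a Captum function):

```python
def select_target(logits, target):
    """Return the score at column `target` for each row of a
    (batch_size, num_choices) list of lists, the way an integer
    target picks one output per example."""
    return [row[target] for row in logits]


# one test example, five choices (made-up numbers)
logits = [[0.1, -0.4, 0.3, 2.2, 0.0]]
print(select_target(logits, 3))  # [2.2]
```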

For LLMAttribution I am following the Llama2 tutorial. The closest I can get with LLMAttribution seems to require TextTokenInput for the input, but raw text for the target?

    from captum.attr import LLMAttribution, TextTokenInput

    in0 = tst['input_ids'][0][0]  # first choice row
    in0_tokens = tokenizer.convert_ids_to_tokens(in0)
    # convert_tokens_to_string re-joins '##' pieces; ' '.join would keep them
    in0Txt = tokenizer.convert_tokens_to_string(in0_tokens)
    in4captum = TextTokenInput(in0Txt, tokenizer, skip_tokens=skip_tokens)

    target = targetList[egIdx]  # index of the correct choice
    targetIn = tst['input_ids'][0][target]
    targ_tokens = tokenizer.convert_ids_to_tokens(targetIn)
    targTxt = tokenizer.convert_tokens_to_string(targ_tokens)
    # targ4captum = TextTokenInput(targTxt, tokenizer, skip_tokens=skip_tokens)

    llm_attr = LLMAttribution(fa, tokenizer)
    attributions_fa = llm_attr.attribute(in4captum, target=targTxt)

but this raises an exception, because prepare_inputs_for_generation isn’t
implemented for this BertForMultipleChoice model:

    Traceback (most recent call last):
      File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 874, in <module>
        main()
      File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 854, in main
        captumPerturb(model,tokenizer,tstEGTensorDict,tstEGtarget,OutDir)
      File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 479, in captumPerturb
        attributions_fa = llm_attr.attribute(in4captum, target=targTxt)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 667, in attribute
        cur_attr = self.attr_method.attribute(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/log/dummy_log.py", line 39, in wrapper
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
        initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
                                                      ^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/_utils/common.py", line 588, in _run_forward
        output = forward_func(
                 ^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 567, in _forward_func
        model_inputs = self.model.prepare_inputs_for_generation(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
        raise NotImplementedError(
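As I understand Captum's implementation, LLMAttribution wants text for the target because it is built for generative (causal) language models: it feeds the target tokens in one at a time and scores them by the log-probability the model assigns to each target token given everything before it. That needs next-token logits over the vocabulary (hence `prepare_inputs_for_generation`), whereas BertForMultipleChoice returns one logit per choice. A stdlib sketch of that kind of sequence scoring; `next_token_logits` is a hypothetical stand-in for the model's per-step outputs, and the function names are mine, not Captum's:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def target_log_prob(next_token_logits, target_ids):
    """Sum over the target of log p(target_ids[t] | prefix).

    next_token_logits[t] is a vocabulary-sized list of logits that
    predicts target token t.
    """
    if len(next_token_logits) != len(target_ids):
        raise ValueError("need one logits row per target token")
    total = 0.0
    for logits, tok in zip(next_token_logits, target_ids):
        total += log_softmax(logits)[tok]
    return total


# toy 3-token vocabulary, 2-token target [2, 0]
step_logits = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
print(target_log_prob(step_logits, [2, 0]))
```

With an integer target, by contrast, the attribution is taken with respect to a single existing output column, so no generation step is involved.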

Thanks for any suggestions!

To patch around this unimplemented prepare_inputs_for_generation
error, I am now trying to construct the model inputs directly:

    # model_inputs = self.model.prepare_inputs_for_generation(
    #     model_inp, **model_kwargs
    # )

    # pass the attention mask explicitly instead
    addtl_model_inputs = {'attention_mask': attention_mask}
    outputs = self.model.forward(model_inp, **addtl_model_inputs)

The BertForMultipleChoice model’s forward function takes input_ids, but
also attention_mask and token_type_ids (the latter marks which tokens
belong to the prompt and which to the choice).
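For reference, BERT's segment convention for a pair `[CLS] prompt [SEP] choice [SEP]` is `token_type_ids` of 0 up to and including the first `[SEP]`, then 1 for the choice tokens and the final `[SEP]`; padding gets 0. If `token_type_ids` is omitted, BertModel uses all zeros, so the choice would be embedded as if it were more prompt. A minimal stdlib sketch of rebuilding them from the ids alone; `segment_ids` is a hypothetical helper and the ids below are toy values (101/102/0 are `[CLS]`/`[SEP]`/`[PAD]` in bert-base-uncased, which may not match this model's vocabulary):

```python
def segment_ids(input_ids, sep_id):
    """Return BERT-style token_type_ids for one [CLS] A [SEP] B [SEP] row.

    Positions up to and including the first [SEP] get 0, positions up to
    and including the second [SEP] get 1, anything after (padding) gets 0.
    """
    first = input_ids.index(sep_id)
    second = input_ids.index(sep_id, first + 1)
    return [1 if first < i <= second else 0 for i in range(len(input_ids))]


row = [101, 7, 8, 102, 9, 10, 102, 0, 0]
print(segment_ids(row, 102))  # [0, 0, 0, 0, 1, 1, 1, 0, 0]
```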

The FeatureAblation constructor allows specification of a tailored
forward function, and that’s the strategy I’ve used previously with
LayerIntegratedGradients:

    from functools import partial
    from captum.attr import LayerIntegratedGradients
    MCfwd = partial(multChoice_forward, model)
    lig = LayerIntegratedGradients(MCfwd, model.bert.embeddings)

But providing it to FeatureAblation

    from captum.attr import FeatureAblation
    MCfwd = partial(multChoice_forward, model)
    fa = FeatureAblation(MCfwd)

causes BaseLLMAttribution.__init__ to confuse the model with the
forward_func, because of this code (line 374 of llm_attr.py):

    # alias, we really need a model and don't support wrapper functions
    # coz we need call model.forward, model.generate, etc.
    self.model: nn.Module = cast(nn.Module, self.forward_func)

and this confusion immediately causes trouble, e.g., when it looks for
the device of the function :(
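The underlying reason is that `typing.cast` does nothing at runtime; it only tells a type checker to treat the value as an `nn.Module`. So `self.model` is still the `functools.partial` object, which lacks the module attributes (such as a `device` attribute or `parameters()`) that the rest of BaseLLMAttribution relies on. A stdlib-only illustration; `FakeModule` and `mult_choice_forward` are hypothetical stand-ins for `nn.Module` and `multChoice_forward`:

```python
from functools import partial
from typing import cast


class FakeModule:
    """Stand-in for nn.Module; only here so cast() has a target type."""


def mult_choice_forward(model, input_ids, attention_mask, token_type_ids):
    """Hypothetical stand-in for multChoice_forward."""
    return model, input_ids, attention_mask, token_type_ids


fwd = partial(mult_choice_forward, "the-model")
as_module = cast(FakeModule, fwd)

print(as_module is fwd)              # True: cast returned the same object
print(type(as_module).__name__)      # partial
print(hasattr(as_module, "device"))  # False
```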

FeatureAblation.attribute() accepts additional_forward_args, and that sounds like
what I want: “If the forward function requires additional arguments
other than the inputs for which attributions should not be computed,
this argument can be provided…”
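To make that documented behavior concrete: in feature ablation only `inputs` are perturbed, while `additional_forward_args` are passed unchanged to every forward call, which is where `attention_mask` and `token_type_ids` would belong. The sketch below is a pure-Python toy of the idea, not Captum's implementation; `ablate`, `toy_forward` and its weights are made up.

```python
def ablate(forward, inputs, additional_forward_args=(), baseline=0.0):
    """Toy single-feature ablation.

    Each element of `inputs` is replaced by `baseline` in turn; its
    attribution is the drop in the forward output. The extra args are
    forwarded untouched on every call and are never ablated.
    """
    base_out = forward(inputs, *additional_forward_args)
    attributions = []
    for i in range(len(inputs)):
        perturbed = list(inputs)
        perturbed[i] = baseline
        attributions.append(base_out - forward(perturbed, *additional_forward_args))
    return attributions


def toy_forward(features, mask, weights):
    """Made-up model: weighted sum of the unmasked features."""
    return sum(f * m * w for f, m, w in zip(features, mask, weights))


print(ablate(toy_forward, [1.0, 2.0, 3.0],
             additional_forward_args=([1, 1, 0], [0.5, 1.0, 2.0])))
# [0.5, 2.0, 0.0]
```

The masked-out third feature gets zero attribution, because the mask travels with every call rather than being perturbed.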

But trying to provide this argument generates the following error (note
that it is raised by FeatureAblation.attribute(), not by the constructor):

	TypeError: captum.attr._core.feature_ablation.FeatureAblation.attribute() got multiple values for keyword argument 'additional_forward_args'
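I believe this happens because `LLMAttribution.attribute()` calls `self.attr_method.attribute(...)` with its own `additional_forward_args` and passes any extra keyword arguments alongside it, so a second `additional_forward_args` collides. The same Python error can be reproduced with two hypothetical stand-in functions:

```python
def fa_attribute(inputs, additional_forward_args=None, **kwargs):
    """Stand-in for FeatureAblation.attribute()."""
    return inputs, additional_forward_args, kwargs


def llm_attribute(inputs, **kwargs):
    """Stand-in for a wrapper that supplies its own additional_forward_args
    and passes the caller's keyword arguments straight through."""
    return fa_attribute(inputs, additional_forward_args=("wrapper",), **kwargs)


try:
    llm_attribute("text", additional_forward_args=("mine",))
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'additional_forward_args'
```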

I note that LLMAttribution._forward_func() helpfully computes an
attention mask and adds it to model_kwargs, so I only need to get the
token_type_ids into the model’s forward() method.
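A minimal sketch of that idea, assuming the wrapper receives `input_ids` plus keyword arguments (such as `attention_mask`) like those `_forward_func()` builds: bind a fixed `token_type_ids` once and add it to every call. Everything here is hypothetical and stdlib-only; `fake_forward` stands in for `model.forward`, and with a real model the bound value would be a tensor matching the input's shape.

```python
from functools import partial


def forward_with_segments(forward, input_ids, token_type_ids, **model_kwargs):
    """Call `forward` with the caller's kwargs plus a pre-bound token_type_ids."""
    return forward(input_ids, token_type_ids=token_type_ids, **model_kwargs)


def fake_forward(input_ids, attention_mask=None, token_type_ids=None):
    """Stand-in for model.forward: echoes what it received."""
    return {"input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids}


ids = [101, 7, 102, 9, 102]
fwd = partial(forward_with_segments, fake_forward, token_type_ids=[0, 0, 0, 1, 1])
print(fwd(ids, attention_mask=[1, 1, 1, 1, 1]))
```

One caveat: as I read `_forward_func()`, LLMAttribution appends target tokens to the input one step at a time, so a fixed-length token_type_ids would stop lining up with the growing input and would have to be extended as well.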