Target not provided when necessary, cannot take gradient with respect to multiple outputs?

:bug: Bug

Using a tiny BERT model on a binary text classification task, my call to lig.attribute() fails on the assertion below.

To Reproduce

(I've put the full script here: rikHak/captum-numelBug.py at master · rbelew/rikHak · GitHub)

I'm using the overruling dataset, a small (n=2400, evenly split) binary classification task, and a tiny BERT model (4.4M params).

I'm also assuming (following this FAQ) that I don't need to explicitly provide a target.
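
For what it's worth, my mental model of when target is optional is sketched below (a toy stand-in linear classifier, not my actual code): if the forward function returns one scalar per example, no target should be needed; if it returns per-class logits, Captum needs target to pick a single column before taking the gradient.

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 2)   # toy stand-in classifier, not my real BERT
    x = torch.randn(4, 16)

    def forward_scalar(inputs):
        # one scalar per example, shape [4]: my reading of the FAQ is
        # that no target is needed in this case
        return torch.softmax(model(inputs), dim=1)[:, 1]

    def forward_logits(inputs):
        # per-class logits, shape [4, 2]: here Captum would need a
        # target to select a single column before taking the gradient
        return model(inputs)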

My call to attribute() looks like this:

attributions_ig, delta = lig.attribute(in_tensor, reference_indices,
                                       additional_forward_args=(ttype_tensor, attn_tensor),
                                       n_steps=500, return_convergence_delta=True)
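
If a target really is required for a two-logit output, I'd guess the fix looks something like the call below (target=1 is just a placeholder for whichever class I want attributions for; it is not in my original call):

    attributions_ig, delta = lig.attribute(
        in_tensor,
        baselines=reference_indices,
        additional_forward_args=(ttype_tensor, attn_tensor),
        target=1,  # placeholder class index, not in my original call
        n_steps=500,
        return_convergence_delta=True,
    )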

  • Execution makes it to LayerIntegratedGradients.gradient_func (L#480)
    and then fails on this assertion:

        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
    
  • Indeed, output[0].shape is torch.Size([500, 2]), so
    output[0].numel() is 500 × 2 = 1000, not 1.

  • Looking back up the stack, _attribute() receives the right inputs
    argument (inputs[0].shape is [1, 128, 128]), but the
    scaled_features_tpl passed on to gradient_func as inputs has shape
    [500, 128, 128]?! I also notice that the inputs have ALSO been
    prepended (redundantly?) onto additional_forward_args (see my
    shape sketch just after this list).
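
That [500, 128, 128] shape makes sense on reflection: integrated gradients builds n_steps=500 scaled copies of my [1, 128, 128] input, stacked along the batch dimension, so the forward pass returns [500, 2] logits and numel() is 1000. A rough sketch of the scaling step as I understand it (illustrative shapes only, not Captum's actual code):

    import torch

    n_steps = 500
    inputs = torch.randn(1, 128, 128)    # stand-in for my real input tensor
    baseline = torch.zeros_like(inputs)  # stand-in for reference_indices

    # one interpolation point per step, stacked on the batch dimension
    alphas = torch.linspace(0, 1, n_steps).view(-1, 1, 1)
    scaled = baseline + alphas * (inputs - baseline)
    print(scaled.shape)                  # torch.Size([500, 128, 128])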

(I've also posted this as a captum issue here.)

Expected behavior

I've made captum work on some of the simpler examples (Titanic, IMDB) and thought this would work, too. What am I doing wrong, please?

Environment

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.3.0
[pip3] torchdata==0.8.0
[pip3] torchtext==0.18.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torchdata 0.8.0 pypi_0 pypi
[conda] torchtext 0.18.0 pypi_0 pypi

Come on, folks! This is a simple, well-documented example of basic captum features, and I am stuck. The full script is here: rikHak/captum-numelBug.py at master · rbelew/rikHak · GitHub. Please help!