Passing hidden state through LayerIntegratedGradients

I was trying to replicate the CNN text model interpretability example using an LSTM, but LayerIntegratedGradients throws an error about the missing hidden state:

forward() missing 1 required positional argument: 'hidden'
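For context, the error comes from a forward() that takes the hidden state as a required positional argument, which Captum does not supply by default. The LSTMClassifier below is a hypothetical stand-in (not the model from the original example), just to make the failure mode concrete:

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size=10_000, embed_dim=100, hidden_dim=256, num_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text, hidden):
        # `hidden` is a required positional argument, so calling
        # model(text) alone raises the error quoted above
        embedded = self.embedding(text)
        output, _ = self.lstm(embedded, hidden)
        return self.fc(output[:, -1, :])

    def init_hidden(self, batch_size):
        # (h_0, c_0) for a single-layer, unidirectional LSTM
        h0 = torch.zeros(1, batch_size, self.hidden_dim)
        c0 = torch.zeros(1, batch_size, self.hidden_dim)
        return (h0, c0)
```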

What would be the best way of handling this?

One option is to create a function wrapper (a partial function or a lambda) that hides the "auxiliary" parameters from Captum; a sketch follows below.
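A minimal sketch of the wrapper approach, assuming the hypothetical LSTMClassifier above (the names, shapes, and target index are all illustrative); the wrapper builds the hidden state itself, so Captum only ever sees the text input:

```python
import torch
from captum.attr import LayerIntegratedGradients

model = LSTMClassifier()    # hypothetical model from the sketch above
model.eval()

def wrapped_forward(text):
    # Captum batches scaled copies of the input internally, so the batch
    # dimension seen here can be larger than the original batch; build
    # the hidden state from the runtime batch size so the shapes match.
    hidden = model.init_hidden(text.shape[0])
    return model(text, hidden)

# Captum only sees the single `text` argument of the wrapped forward.
lig = LayerIntegratedGradients(wrapped_forward, model.embedding)

input_ids = torch.randint(0, 10_000, (1, 20))   # dummy token ids
baseline_ids = torch.zeros_like(input_ids)      # e.g. an all-<pad> baseline

attributions, delta = lig.attribute(
    input_ids,
    baselines=baseline_ids,
    target=1,                        # class index to attribute
    return_convergence_delta=True,
)
```

A functools.partial or lambda that bakes in a pre-computed hidden state works the same way, as long as its batch dimension matches what forward() actually receives; Captum repeats the batch across integration steps internally, which is why the wrapper above derives it from the incoming batch.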

The other option is to use the additional_forward_args argument (see the docs); a sketch of that route follows as well.
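A sketch of the additional_forward_args route, again assuming the hypothetical model above. The extra argument is routed through to forward() unchanged; here it is None, which nn.LSTM treats as a zero initial state and which stays valid for every batch size Captum uses internally:

```python
import torch
from captum.attr import LayerIntegratedGradients

model = LSTMClassifier()    # hypothetical model from the sketch above
model.eval()

lig = LayerIntegratedGradients(model, model.embedding)

input_ids = torch.randint(0, 10_000, (1, 20))   # dummy token ids
baseline_ids = torch.zeros_like(input_ids)      # e.g. an all-<pad> baseline

attributions, delta = lig.attribute(
    input_ids,
    baselines=baseline_ids,
    target=1,                            # class index to attribute
    additional_forward_args=(None,),     # forwarded as model(text, None)
    return_convergence_delta=True,
)
```

Keep in mind that whatever you pass here is held fixed while Captum scales and batches the inputs internally, so a concrete (h_0, c_0) pair with a fixed batch dimension can run into shape mismatches with an LSTM; if you need a specific non-zero hidden state, the wrapper approach above is usually the easier route.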
