Hi, I’ve been working on a FeatureExplainer for a probabilistic TiDE model and eventually came up with this solution:

```python
class PlainProbabilisticTiDEModel(nn.Module):
    """Deterministic tensor-in/tensor-out wrapper around a probabilistic darts ``TiDEModel``.

    Draws ``num_samples`` Monte-Carlo forecasts from the model's likelihood and
    returns their element-wise median, so that attribution methods (e.g. captum)
    see a plain deterministic ``nn.Module``.
    """

    def __init__(self, model: "TiDEModel", num_samples: int):
        """
        Parameters
        ----------
        model
            A fitted darts ``TiDEModel``; must expose ``.model`` (the inner torch
            module), ``.likelihood`` and ``.batch_size``.
        num_samples
            Number of likelihood samples used to compute the median forecast.

        Raises
        ------
        ValueError
            If ``num_samples`` is smaller than 1 (the sampling loop would
            otherwise produce no predictions and ``torch.cat`` would fail).
        """
        super().__init__()
        if num_samples < 1:
            raise ValueError(f"`num_samples` must be >= 1, got {num_samples}")
        self.num_samples = num_samples
        self.model = model.model
        self.likelihood = model.likelihood
        self.batch_size = model.batch_size

    @staticmethod
    def _sample_tiling(input_data_tuple, batch_sample_size: int):
        # Repeat every input tensor along the batch dimension so that a single
        # forward pass yields `batch_sample_size` samples per series.
        # NOTE(review): this helper was missing from the original class (it is
        # defined on darts' PLForecastingModule, which this wrapper does not
        # inherit from); implementation mirrors the darts version — confirm
        # against the darts release in use.
        tiled = []
        for tensor in input_data_tuple:
            if tensor is not None:
                tiled.append(tensor.tile((batch_sample_size,) + (1,) * (tensor.dim() - 1)))
            else:
                tiled.append(None)
        return tuple(tiled)

    def forward(
        self,
        x_past: torch.Tensor,
        x_future: torch.Tensor,
        x_static: torch.Tensor,
    ) -> torch.Tensor:
        """Return the median forecast over ``self.num_samples`` likelihood samples.

        Parameters
        ----------
        x_past
            Past input chunk, shape ``(num_series, input_chunk_length, n_past_features)``
            — assumed layout; confirm against the caller.
        x_future
            Future covariates for the forecast horizon.
        x_static
            Static covariates per series.

        Returns
        -------
        torch.Tensor
            Median forecast, shape ``(num_series,) + <per-sample output shape>``.
        """
        num_series = x_past.shape[0]
        # Fit as many sample copies of the batch as possible into one model call.
        batch_sample_size = min(max(self.batch_size // num_series, 1), self.num_samples)
        batch_predictions = []
        sample_count = 0
        while sample_count < self.num_samples:
            # Shrink the final chunk so we never produce more than `num_samples`.
            if sample_count + batch_sample_size > self.num_samples:
                batch_sample_size = self.num_samples - sample_count
            # Stack multiple copies of the inputs to draw several samples at once.
            input_data_tuple_samples = self._sample_tiling(
                (x_past, x_future, x_static), batch_sample_size
            )
            # One batch of predictions, covering all series and several samples.
            output = self.model(input_data_tuple_samples)
            # Sample concrete values from the predicted likelihood parameters.
            output = self.likelihood.sample(output)
            # (num_series * batch_sample_size, ...) -> (batch_sample_size, num_series, ...)
            # so that dim 0 indexes the samples. Valid because `_sample_tiling`
            # repeats the whole series block, i.e. dim 0 is sample-major.
            output = output.reshape((batch_sample_size, num_series) + output.shape[1:])
            batch_predictions.append(output)
            sample_count += batch_sample_size
        # (num_samples, num_series, ...) -> median over the sample dimension.
        return torch.cat(batch_predictions, dim=0).median(dim=0).values
class TideExplainer:
    """Feature-importance explainer for a (probabilistic) darts ``TiDEModel``.

    Runs captum's ``FeatureAblation`` on a deterministic (median-forecast)
    wrapper of the model and aggregates the attributions into per-timestep,
    per-past-feature, per-future-feature and per-static-feature importances.
    """

    def __init__(self, model: "TiDEModel", num_samples: int):
        """
        Parameters
        ----------
        model
            A fitted darts ``TiDEModel``.
        num_samples
            Number of likelihood samples used for the median forecast.
        """
        self.model = PlainProbabilisticTiDEModel(model=model, num_samples=num_samples)
        self._collate_fn = model._batch_collate_fn
        self.uses_static_covariates = model.uses_static_covariates
        # Wrap the module so captum treats each positional tensor as a feature group.
        model_wrapped = ModelInputWrapper(self.model)
        self.method = FeatureAblation(model_wrapped)

    @staticmethod
    def _process_input_batch(input_batch):
        """Convert a darts inference batch into ``(x_past, x_future, x_static)`` tensors.

        NOTE(review): this method was missing from the original class; the
        implementation mirrors darts' ``MixedCovariatesTorchModel._process_input_batch``
        (past target, past covariates and historic future covariates concatenated
        along the component axis) — confirm against the darts version in use.
        """
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
        ) = input_batch
        x_past = torch.cat(
            [
                tensor
                for tensor in (past_target, past_covariates, historic_future_covariates)
                if tensor is not None
            ],
            dim=2,
        )
        return x_past, future_covariates, static_covariates

    def explain(
        self,
        target_series: "Union[TimeSeries, Sequence[TimeSeries]]",
        past_covariates: "Optional[Union[TimeSeries, Sequence[TimeSeries]]]",
        future_covariates: "Optional[Union[TimeSeries, Sequence[TimeSeries]]]",
        n: int,
        input_chunk_length: int,
        output_chunk_length: int,
        batch_size: int = 1,
        perturbations_per_eval: int = 100,
        verbose: bool = True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Create feature importances for given time series data.

        Parameters
        ----------
        target_series
            The target series that are to be predicted into the future.
        past_covariates
            Some past-observed covariates that are used for predictions.
        future_covariates
            Some future-known covariates that are used for predictions.
        n
            Forecast horizon: the number of time steps to predict after the end of the target series.
        input_chunk_length
            The length of the target series the model takes as input.
        output_chunk_length
            The length of the target series the model emits in output.
        batch_size
            Number of time series to process simultaneously.
        perturbations_per_eval
            Feature ablation (https://captum.ai/api/feature_ablation.html) parameter.
        verbose
            Whether to show a progress bar for the data loader.

        Returns
        -------
        time_importances
            ``np.ndarray`` of shape `(n_series, input_chunk_length)` (first dim is the
            total number of input series across all batches), describing the dependence
            of the forecast on each historical timestep.
        past_importances
            ``np.ndarray`` of shape `(n_series, n_targets + n_past_features + n_future_features)`,
            describing the dependence of the forecast on each past feature at historical timesteps.
        future_importances
            ``np.ndarray`` of shape `(n_series, n_future_features)`,
            describing the dependence of the forecast on each future feature at future timesteps.
        static_importances
            ``np.ndarray`` of shape `(n_series, n_static_features)`,
            describing the dependence of the forecast on each static feature.
        """
        inference_dataset = MixedCovariatesInferenceDataset(
            target_series=target_series,
            past_covariates=past_covariates,
            future_covariates=future_covariates,
            n=n,
            input_chunk_length=input_chunk_length,
            output_chunk_length=output_chunk_length,
            use_static_covariates=self.uses_static_covariates,
        )
        loader = DataLoader(
            inference_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
            collate_fn=self._collate_fn,
        )
        time_importances, past_importances, future_importances, static_importances = [], [], [], []
        # BUG FIX: tqdm has no `display` keyword — suppression is done via `disable`.
        for batch in tqdm(loader, disable=not verbose):
            (
                x_past_target,
                x_past_covariates,
                x_historic_future_covariates,
                x_future_covariates,
                x_future_past_covariates,
                x_static_covariates,
                _,
            ) = batch
            input_batch = self._process_input_batch(
                input_batch=(
                    x_past_target,
                    x_past_covariates,
                    x_historic_future_covariates,
                    x_future_covariates,
                    x_static_covariates,
                )
            )
            # Attribute the forecast to each input tensor:
            # x_past_attrs:
            #   (n_series * forecast_horizon, input_length, n_targets + n_past_features + n_future_features)
            # x_future_attrs: (n_series * forecast_horizon, input_length, n_future_features)
            # x_static_attrs: (n_series * forecast_horizon, n_targets, n_static_features)
            x_past_attrs, x_future_attrs, x_static_attrs = self.method.attribute(
                inputs=input_batch,
                perturbations_per_eval=perturbations_per_eval,
            )
            # BUG FIX: the last batch may contain fewer than `batch_size` series
            # (drop_last=False), so reshape with the actual series count.
            n_series = x_past_target.shape[0]
            # (n_series * forecast_horizon, ...) -> (n_series, num_timestamps, ...)
            x_past_attrs = x_past_attrs.reshape(n_series, -1, *x_past_attrs.shape[1:])
            x_future_attrs = x_future_attrs.reshape(n_series, -1, *x_future_attrs.shape[1:])
            x_static_attrs = x_static_attrs.reshape(n_series, -1, *x_static_attrs.shape[1:])
            # Aggregate attributions into importances by summing out the other axes.
            time_importance = x_past_attrs.sum(dim=3).sum(dim=1)  # (n_series, input_length)
            # (n_series, n_targets + n_past_features + n_future_features)
            past_importance = x_past_attrs.sum(dim=2).sum(dim=1)
            future_importance = x_future_attrs.sum(dim=2).sum(dim=1)  # (n_series, n_future_features)
            static_importance = x_static_attrs.sum(dim=2).sum(dim=1)  # (n_series, n_static_features)
            time_importances.append(time_importance)
            past_importances.append(past_importance)
            future_importances.append(future_importance)
            static_importances.append(static_importance)
        # detach/cpu first so `.numpy()` also works for CUDA tensors and
        # attributions that still require grad.
        time_importances = torch.cat(time_importances, dim=0).detach().cpu().numpy()
        past_importances = torch.cat(past_importances, dim=0).detach().cpu().numpy()
        future_importances = torch.cat(future_importances, dim=0).detach().cpu().numpy()
        static_importances = torch.cat(static_importances, dim=0).detach().cpu().numpy()
        return time_importances, past_importances, future_importances, static_importances
```

Basically, it’s a feature-ablation approach with additional darts-to-torch transformations.

My question is: is there any way to apply a gradient-based attribution method to a model whose output has shape `(batch_size, n_timestamps, n_targets)`? If not, is a perturbation/permutation method the only option for a forecasting model?