Feature attribution for a darts model

Hi, I’ve been working on a feature explainer (`TideExplainer` below) for a probabilistic TiDE model and eventually came up with this solution:

class PlainProbabilisticTiDEModel(nn.Module):
    """Plain ``nn.Module`` view of a fitted probabilistic darts ``TiDEModel``.

    Exposes the model's underlying torch module and likelihood behind a
    ``forward(x_past, x_future, x_static)`` signature that attribution
    libraries can call with ordinary tensors. Each forward pass draws
    ``num_samples`` samples from the likelihood and returns their median,
    i.e. one point forecast per series.

    Parameters
    ----------
    model
        Fitted darts ``TiDEModel``; its ``model``, ``likelihood`` and
        ``batch_size`` attributes are read.
    num_samples
        Number of likelihood samples aggregated into the median forecast.
        Must be at least 1.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    """

    def __init__(self, model: "TiDEModel", num_samples: int):
        super().__init__()
        if num_samples < 1:
            raise ValueError(f"`num_samples` must be >= 1, got {num_samples}")
        self.model = model.model
        self.likelihood = model.likelihood
        self.batch_size = model.batch_size
        self.num_samples = num_samples
        # Attribution compares forward passes on original vs. perturbed
        # inputs; any dropout left in training mode would add random noise to
        # every difference. NOTE: this also switches the wrapped darts module
        # (shared with `model`) to eval mode.
        self.model.eval()

    @staticmethod
    def _sample_tiling(input_data_tuple: tuple, batch_sample_size: int) -> tuple:
        """Repeat every tensor ``batch_sample_size`` times along dim 0.

        Whole copies of the batch are stacked, so the result is sample-major:
        rows ``[i * num_series, (i + 1) * num_series)`` belong to sample ``i``.
        ``None`` entries (absent covariates) are passed through unchanged.
        """
        return tuple(
            None
            if tensor is None
            # full-length repeat spec: `tile` with fewer dims than the tensor
            # would tile the trailing axes instead of dim 0
            else tensor.tile((batch_sample_size,) + (1,) * (tensor.dim() - 1))
            for tensor in input_data_tuple
        )

    def forward(
        self,
        x_past: torch.Tensor,
        x_future: torch.Tensor,
        x_static: torch.Tensor,
    ) -> torch.Tensor:
        """Return the element-wise median of ``num_samples`` sampled forecasts.

        Parameters
        ----------
        x_past
            Past model inputs; dim 0 indexes the series in the batch.
        x_future
            Future-known covariates, or ``None``.
        x_static
            Static covariates, or ``None``.

        Returns
        -------
        torch.Tensor
            Median over samples with shape ``(num_series, *sample_shape)``,
            where ``sample_shape`` is the per-row shape produced by
            ``likelihood.sample`` (presumably ``(n_timesteps, n_components)``).
            For an even sample count ``torch.median`` returns the lower of the
            two middle values.
        """
        num_series = x_past.shape[0]
        # Draw several samples per model call (by tiling the inputs) while
        # keeping each call at roughly `batch_size` rows.
        batch_sample_size = min(
            max(self.batch_size // num_series, 1), self.num_samples
        )

        batch_predictions = []
        sample_count = 0
        while sample_count < self.num_samples:
            # the last chunk must not overshoot `num_samples`
            batch_sample_size = min(
                batch_sample_size, self.num_samples - sample_count
            )

            tiled_inputs = self._sample_tiling(
                (x_past, x_future, x_static), batch_sample_size
            )
            # model emits likelihood parameters; sample one value per element
            output = self.likelihood.sample(self.model(tiled_inputs))

            # tiling is sample-major along dim 0, so split it into
            # (batch_sample_size, num_series, ...)
            output = output.reshape(
                (batch_sample_size, num_series) + tuple(output.shape[1:])
            )

            batch_predictions.append(output)
            sample_count += batch_sample_size

        return torch.cat(batch_predictions, dim=0).median(dim=0).values


class TideExplainer:
    """Feature-ablation attributions for a probabilistic darts ``TiDEModel``.

    The model is wrapped in ``PlainProbabilisticTiDEModel`` (median of
    ``num_samples`` likelihood samples per forward pass), then in captum's
    ``ModelInputWrapper``, and attributed with captum's ``FeatureAblation``.
    """

    def __init__(self, model: "TiDEModel", num_samples: int):
        self.model = PlainProbabilisticTiDEModel(
            model=model,
            num_samples=num_samples,
        )
        self._collate_fn = model._batch_collate_fn
        self.uses_static_covariates = model.uses_static_covariates

        model_wrapped = ModelInputWrapper(self.model)
        self.method = FeatureAblation(model_wrapped)

    @staticmethod
    def _process_input_batch(input_batch):
        """Assemble the ``(x_past, x_future, x_static)`` forward inputs.

        Concatenates past target, past covariates and historic future
        covariates along the feature axis (dim 2), skipping ``None`` entries,
        and passes future and static covariates through unchanged.
        NOTE(review): intended to mirror darts'
        ``PLMixedCovariatesModule._process_input_batch``; confirm against the
        installed darts version.
        """
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
        ) = input_batch
        x_past = torch.cat(
            [
                tensor
                for tensor in (
                    past_target,
                    past_covariates,
                    historic_future_covariates,
                )
                if tensor is not None
            ],
            dim=2,
        )
        return x_past, future_covariates, static_covariates

    def explain(
        self,
        target_series: Union["TimeSeries", Sequence["TimeSeries"]],
        past_covariates: Optional[Union["TimeSeries", Sequence["TimeSeries"]]],
        future_covariates: Optional[Union["TimeSeries", Sequence["TimeSeries"]]],
        n: int,
        input_chunk_length: int,
        output_chunk_length: int,
        batch_size: int = 1,
        perturbations_per_eval: int = 100,
        verbose: bool = True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Create feature importances for given time series data.

        Each importance is a captum ``FeatureAblation`` attribution, summed over
        all forecast output elements and over the input axes not kept in the
        result.

        Parameters
        ----------
        target_series
            The target series that are to be predicted into the future.
        past_covariates
            Some past-observed covariates that are used for predictions.
        future_covariates
            Some future-known covariates that are used for predictions.
        n
            Forecast horizon: The number of time steps to predict after the end of the target series.
            NOTE(review): only a single model forward pass is attributed, so ``n``
            presumably should not exceed ``output_chunk_length``; TODO confirm.
        input_chunk_length
            The length of the target series the model takes as input.
        output_chunk_length
            The length of the target series the model emits in output.
        batch_size
            Number of time series to process simultaneously. The last batch may
            hold fewer series.
        perturbations_per_eval
            Feature ablation (https://captum.ai/api/feature_ablation.html) parameter.
        verbose
            Whether to show a tqdm progress bar over the data loader.

        Returns
        -------
        Each array has one row per item of the inference dataset (``n_items``),
        in loader order.

        time_importances
            ``np.ndarray`` of shape `(n_items, input_chunk_length)`,
            which describes the dependence of the forecast on each historical timestep.
        past_importances
            ``np.ndarray`` of shape `(n_items, n_targets + n_past_features + n_future_features)`,
            which describes the dependence of the forecast on each past feature at historical timesteps.
        future_importances
            ``np.ndarray`` of shape `(n_items, n_future_features)`,
            which describes the dependence of the forecast on each future feature at future timesteps.
        static_importances
            ``np.ndarray`` of shape `(n_items, n_static_features)`,
            which describes the dependence of the forecast on each static feature.
        """
        inference_dataset = MixedCovariatesInferenceDataset(
            target_series=target_series,
            past_covariates=past_covariates,
            future_covariates=future_covariates,
            n=n,
            input_chunk_length=input_chunk_length,
            output_chunk_length=output_chunk_length,
            use_static_covariates=self.uses_static_covariates,
        )
        loader = DataLoader(
            inference_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
            collate_fn=self._collate_fn,
        )
        # inputs must live on the same device as the wrapped model
        device = next(self.model.parameters()).device

        time_importances, past_importances, future_importances, static_importances = [], [], [], []
        # tqdm's switch is `disable`; an unknown kwarg such as `display` is rejected
        for batch in tqdm(loader, disable=not verbose):
            (
                x_past_target,
                x_past_covariates,
                x_historic_future_covariates,
                x_future_covariates,
                _,  # future part of past covariates: not a model input here
                x_static_covariates,
                _,  # unused here
            ) = batch
            # the last batch can be smaller than `batch_size` (drop_last=False),
            # so reshape with the actual number of series in this batch
            n_series = x_past_target.shape[0]

            input_batch = tuple(
                None if tensor is None else tensor.to(device)
                for tensor in self._process_input_batch(
                    (
                        x_past_target,
                        x_past_covariates,
                        x_historic_future_covariates,
                        x_future_covariates,
                        x_static_covariates,
                    )
                )
            )

            # Attribute the forecast to every input element. The leading dim is
            # observed to be one slice per output element, i.e.
            # n_series * (forecast steps * n_targets); TODO confirm per captum version.
            # x_past_attrs:   (..., input_chunk_length, n_targets + n_past_features + n_future_features)
            # x_future_attrs: (..., n_future_steps, n_future_features)
            # x_static_attrs: (..., n_targets, n_static_features)
            x_past_attrs, x_future_attrs, x_static_attrs = self.method.attribute(
                inputs=input_batch,
                perturbations_per_eval=perturbations_per_eval,
            )

            # (n_series * n_outputs, ...) -> (n_series, n_outputs, ...)
            x_past_attrs = x_past_attrs.reshape(n_series, -1, *x_past_attrs.shape[1:])
            x_future_attrs = x_future_attrs.reshape(n_series, -1, *x_future_attrs.shape[1:])
            x_static_attrs = x_static_attrs.reshape(n_series, -1, *x_static_attrs.shape[1:])

            # sum out the feature/time axis, then the output axis
            time_importances.append(x_past_attrs.sum(dim=3).sum(dim=1))  # (n_series, input_length)
            past_importances.append(x_past_attrs.sum(dim=2).sum(dim=1))  # (n_series, n_past_inputs)
            future_importances.append(x_future_attrs.sum(dim=2).sum(dim=1))  # (n_series, n_future_features)
            static_importances.append(x_static_attrs.sum(dim=2).sum(dim=1))  # (n_series, n_static_features)

        def _to_numpy(parts):
            # attributions may sit on the model's device; numpy needs CPU memory
            return torch.cat(parts, dim=0).detach().cpu().numpy()

        return (
            _to_numpy(time_importances),
            _to_numpy(past_importances),
            _to_numpy(future_importances),
            _to_numpy(static_importances),
        )

Basically, it’s a feature ablation approach with additional darts-to-torch transformations.

My question is: is there any way to apply a gradient-based attribution method to a model whose output has shape `(batch_size, n_timestamps, n_targets)`? If not, is a perturbation/permutation method the only option for a forecasting model?

2 Likes