Temporal Fusion Transformer Error

I am using the Temporal Fusion Transformer (TFT) for forecasting. It throws the error below when I pass return_y=True while making predictions (baseline_predictions = Baseline().predict(val_dataloader, return_y=True)). It works fine if I set return_y=False.

Error: RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 320 for tensor number 1 in the list.
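For reference, this wording is what PyTorch's torch.cat raises when the tensors being concatenated along one dimension (here dimension 1) disagree in some other dimension. The sketch below is a simplified pure-Python model of that shape rule, not PyTorch's code: it ignores dtype, device, empty tensors, and negative dim values. The shapes are hypothetical, chosen only to reproduce the message above.

```python
def cat_result_shape(shapes, dim):
    """Return the shape a concatenation along `dim` would produce.

    Simplified model of the torch.cat shape rule: every tensor must have the
    same number of dimensions and the same size in every dimension except
    `dim`. Assumes `dim` is non-negative.
    """
    first = shapes[0]
    for n, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: "
                f"got {len(first)} and {len(shape)}"
            )
        for d, (expected, got) in enumerate(zip(first, shape)):
            # Only the concatenation dimension is allowed to differ
            if d != dim and expected != got:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} "
                    f"for tensor number {n} in the list."
                )
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)  # sizes add up along dim
    return tuple(out)


# Hypothetical shapes: matching first dimension, so concatenation works
print(cat_result_shape([(1280, 2), (1280, 3)], dim=1))

# Hypothetical shapes: first dimension differs (1 vs 320), so it fails
try:
    cat_result_shape([(1, 2), (320, 2)], dim=1)
except RuntimeError as err:
    print(err)
```

In other words, "dimension 1" in the message is the dimension being concatenated, and the mismatch is in another dimension (here dimension 0) of one of the pieces.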

Can you please suggest where I am going wrong?

Could you post a minimal and executable code snippet reproducing the error, please?

Hello,

I’m encountering a similar issue while working with the Temporal Fusion Transformer (TFT). My dataset consists of ~5,000 unique series, each identified by a combination of keys such as SKU and store, spanning 24 timesteps (2 years of monthly data). My goal is to forecast the next month for each series.

When I try to calculate metrics like the Mean Absolute Error (MAE) using the Baseline().predict() function with return_y=True, I encounter the following error:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1280 but got size 641 for tensor number 4 in the list.

If I set return_y=False, the code runs without issue. I’ve tried the following setup, but the issue persists:

# Imports used below; `training` is the TimeSeriesDataSet built earlier from `df`
from pytorch_forecasting import Baseline, TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE

# Create validation set (predict=True) to predict the last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

# Create dataloaders for the model
batch_size = 128  # typically between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0, drop_last=True)

# Baseline predictions (the baseline repeats the last observed target value)
baseline_predictions = Baseline().predict(val_dataloader, return_y=True)

# Calculate MAE
MAE()(baseline_predictions.output, baseline_predictions.y)

I configured the TimeSeriesDataSet with the following parameters:

  • max_prediction_length = 2
  • max_encoder_length = 12
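As a quick sanity check on this setup, the arithmetic below counts validation samples and batches. It assumes exactly 5,000 series (the post says ~5,000) and one validation sample per series, which is what predict=True is meant to produce ("the last max_prediction_length points in time for each series"). With batch_size * 10 = 1280 and drop_last=True, PyTorch's DataLoader discards the final incomplete batch. This does not by itself explain the 641 in the error; it only shows how many samples reach the predictions.

```python
import math

# Assumed round numbers taken from the post (hypothetical: exactly 5,000 series)
n_series = 5_000
timesteps = 24
max_encoder_length = 12
max_prediction_length = 2
val_batch_size = 128 * 10  # batch_size * 10, as passed to the validation dataloader

# Each sample needs encoder + prediction steps; a 24-step series fits one window
window = max_encoder_length + max_prediction_length
assert window <= timesteps

n_samples = n_series  # assumption: predict=True gives one sample per series
full_batches = n_samples // val_batch_size
dropped = n_samples - full_batches * val_batch_size  # discarded by drop_last=True
batches_without_drop = math.ceil(n_samples / val_batch_size)

print(window, full_batches, dropped, batches_without_drop)
```

Under these assumptions, 3 full batches (3,840 series) are kept and 1,160 series are dropped; with drop_last=False there would be 4 batches, the last one holding 1,160 samples.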

It seems to be a dimension mismatch issue, but I’m unsure how to resolve it. Could anyone provide guidance on how to fix this?

Thank you!