Temporal Fusion Transformer training on Colab TPU

I was training a TFT model on a Colab GPU. It trained, but it was relatively slow because TFT is a large model. I then tried to train the model on a Colab TPU, but training never gets going: it reaches Epoch 0 and freezes. My first question: is the code below OK in terms of TPU utilization?


import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# prediction horizon covers the whole test window; the encoder sees 4x as much history
max_prediction_length = len(test)
max_encoder_length = 4 * max_prediction_length
# training_cutoff = df_19344_tmp["time_idx"].max() - max_prediction_length


training = TimeSeriesDataSet(
    train,
    time_idx='time_idx',
    target='occupancy',
    group_ids=['property_id'],
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=['property_id'],
    static_reals=[],
    time_varying_known_categoricals=[],
    time_varying_known_reals=['time_idx', 'sin_day', 'cos_day', 'sin_month', 'cos_month', 'sin_year', 'cos_year'],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[],
    target_normalizer=GroupNormalizer(
        groups=['property_id'], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True
)


validation = TimeSeriesDataSet.from_dataset(training, train, predict=True, stop_randomization=True)

batch_size = 32  # set this between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)


trainer = pl.Trainer(
    max_epochs=100,
    # accelerator='cpu',
    accelerator='tpu',
    devices=1,
    enable_model_summary=True,
    auto_lr_find=False,
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
    check_val_every_n_epoch=None,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.0005,
    hidden_size=8,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
)
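
For completeness, training is started with the usual Lightning fit call (a sketch; my notebook cell is equivalent and uses the dataloaders defined above):

# start training with the dataloaders built above
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)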

My second question is: does the TemporalFusionTransformer from pytorch-forecasting support TPU training at all?
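
For context, this is roughly how I confirm that the TPU runtime is visible in the Colab session before training (assuming torch_xla is installed; this snippet is only a sanity check, not part of the training code):

# sanity check: ask torch_xla for the TPU device
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(device)  # prints an XLA device such as xla:0 when the TPU is visible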

This is where the model freezes when training on a Colab TPU: