DeepAR Training gets stuck at some random epoch

  • PyTorch-Forecasting version: # v0.10.2 Multivariate networks (23/05/2022)
  • PyTorch version: pytorch 2.0.0 py3.9_cuda11.7_cudnn8.5.0_0 pytorch
  • Python version: python 3.9.16 h7a1cb2a_2
  • Operating System: Linux, Ubuntu 22.04.2 LTS
  • GPU: RTX4090

I ran the code below to tune hyperparameters with Optuna, training DeepAR and validating after every epoch. I expected the training to run smoothly on a single GPU (I only have one GPU).
Training gets stuck at some random epoch.

Optuna has nothing to do with this; something is wrong in the training itself.


#from ctypes import FormatError
import numpy as np

import warnings
warnings.filterwarnings("ignore")

import os,sys

# sys.path.append(os.path.abspath(os.path.join('C:\Work\WORK_PACKAGE\Demand_forecasting\github\DeepAR-pytorch\My_model\\2_freq_nbinom_LSTM')))

# sys.path.append(os.path.abspath(os.path.join('C:\Work\WORK_PACKAGE\Demand_forecasting\github\DeepAR-pytorch\My_model\\2_freq_nbinom_LSTM\\1_cluster_demand_prediction\data\weather_data')))
# sys.path.append(os.path.abspath(os.path.join('C:\Work\WORK_PACKAGE\Demand_forecasting\github\DeepAR-pytorch\My_model\2_freq_nbinom_LSTM\1_cluster_demand_prediction\data\demand_data')))

import torch
torch.use_deterministic_algorithms(True)

from pytorch_forecasting.data.encoders import TorchNormalizer
from pytorch_forecasting.metrics import SMAPE, RMSE
from torchmetrics import R2Score, SymmetricMeanAbsolutePercentageError, MeanSquaredError

import matplotlib.pyplot as plt
import pandas as pd
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
import pytorch_lightning as pl
import torch
from pytorch_forecasting.data.encoders import TorchNormalizer
import os,sys
import numpy as np
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.graphics.tsaplots import plot_pacf
from statsmodels.tsa.stattools import acf,pacf
from scipy.signal import find_peaks
import operator
import statsmodels.api as sm
from itertools import combinations
import pickle
from pytorch_forecasting import Baseline
import random
from pytorch_forecasting import DeepAR,NegativeBinomialDistributionLoss
from itertools import product
from sklearn.metrics import mean_absolute_error, mean_squared_error
import optuna
from optuna.trial import TrialState
import plotly


# Seed the Python, torch and NumPy random number generators with the same
# fixed value (0) so that repeated runs start from the same RNG state.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
## additional seeding to ensure reproducibility.
# NOTE(review): seed_everything presumably re-seeds the generators above as
# well, which would make the three explicit calls redundant - confirm against
# the installed pytorch-lightning version.
pl.seed_everything(0)


# Single module-level loss instance; objective() passes this same object to
# every DeepAR model it builds.
Loss=NegativeBinomialDistributionLoss()

# Initial covariate lag. NOTE(review): this value is overwritten with 1 further
# down in the script before it is ever read, so the assignment here has no effect.
cov_lag_len= 0 

# Import the pre-processed demand data.
# ("response" and "target" refer to the same thing.)
#
# Each CSV carries a leftover index column named "Unnamed: 0" (as written by a
# previous DataFrame.to_csv call, presumably); it is dropped right after reading.
tampines_all_clstr_train_dem_data = pd.read_csv(
    "tampines_all_clstr_train_dem_data.csv"
).drop(columns=["Unnamed: 0"])
tampines_all_clstr_val_dem_data = pd.read_csv(
    "tampines_all_clstr_val_dem_data.csv"
).drop(columns=["Unnamed: 0"])
tampines_all_clstr_test_dem_data = pd.read_csv(
    "tampines_all_clstr_test_dem_data.csv"
).drop(columns=["Unnamed: 0"])

# Short aliases used by the rest of the script. These are the same DataFrame
# objects, not copies, so columns added through an alias also appear on the
# long-named variable.
train_data = tampines_all_clstr_train_dem_data
val_data = tampines_all_clstr_val_dem_data
test_data = tampines_all_clstr_test_dem_data


# Number of hours by which the start of every split is shifted forward.
cov_lag_len = 1


def _add_calendar_features(frame, first_timestamp, lag_hours):
    """Add a ``date`` column and string-typed calendar columns to *frame* in place.

    ``date`` is ``first_timestamp + lag_hours`` hours plus ``frame.time_idx``
    interpreted as a number of hours. The shift is done with timedelta
    arithmetic, so any lag rolls over day/month boundaries correctly (the
    previous hand-written ``hour % 24`` logic was only right for lags that
    cross at most one midnight).

    Args:
        frame: DataFrame with a ``time_idx`` column; modified in place.
        first_timestamp: ``pd.Timestamp`` of the row with ``time_idx == 0``
            before the lag is applied.
        lag_hours: Hours added to ``first_timestamp``.
    """
    origin = first_timestamp + pd.Timedelta(hours=lag_hours)
    frame["date"] = origin + pd.to_timedelta(frame.time_idx, "h")

    stamp = frame["date"].dt
    # The calendar fields are stored as strings because they are fed to the
    # model as categorical variables further down in the script.
    frame["_hour_of_day"] = stamp.hour.astype(str)
    frame["_day_of_week"] = stamp.dayofweek.astype(str)
    frame["_day_of_month"] = stamp.day.astype(str)
    frame["_day_of_year"] = stamp.dayofyear.astype(str)
    # `.dt.weekofyear` is deprecated; isocalendar().week yields the same ISO week number.
    frame["_week_of_year"] = stamp.isocalendar().week.astype(str)
    frame["_month_of_year"] = stamp.month.astype(str)
    frame["_year"] = stamp.year.astype(str)


#################### add date information ts ####################
# train starts at 17/10/2021 20:00
_add_calendar_features(
    train_data, pd.Timestamp(year=2021, month=10, day=17, hour=20), cov_lag_len
)
# val starts at 3/12/2021 09:00
_add_calendar_features(
    val_data, pd.Timestamp(year=2021, month=12, day=3, hour=9), cov_lag_len
)
# test starts at 16/12/2021 16:00
_add_calendar_features(
    test_data, pd.Timestamp(year=2021, month=12, day=16, hour=16), cov_lag_len
)
#################### add date information ts ####################

# Name of the column used as the prediction target by objective().
Target = "target"


# Full training routine with Bayesian hyperparameter search.
#
# The data is loaded into TimeSeriesDataSet objects inside objective().
# For a fast development run, uncomment `fast_dev_run = fdv_steps` in the
# Trainer arguments.

# Early stopping is currently disabled; the callback that was used is kept for reference:
# early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=p, verbose=False, mode="min")

# Learning-rate logging callback handed to every Trainer created in objective().
lr_logger = LearningRateMonitor()

# Bookkeeping left over from the earlier grid search (finding the minimum-RMSE
# case); not referenced by the Optuna-based code visible in this script.
RMSE_list = list()
hyperparams_list = list()
param_comb_cnt = -1
#for neu,lay,bat,lr,enc_len,pred_len,drop,cov_pair,num_ep in product(*[x for x in hparams_grid.values()]):

def objective(trial):
    """Optuna objective: train DeepAR with sampled hyperparameters and return the validation RMSE.

    Hidden size, number of RNN layers, batch size, learning rate and dropout
    are sampled from ``trial``. For every epoch budget in ``range(20, 35, 2)``
    a new Trainer and a new DeepAR model are created and trained from scratch
    for that many epochs; the resulting ``val_RMSE`` is reported to Optuna as
    the intermediate value for step ``num_ep`` so the trial can be pruned.

    The encoders, datasets and dataloaders do not depend on the epoch budget,
    so they are built once per trial rather than once per budget.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters and to
            report intermediate values.

    Returns:
        The ``val_RMSE`` obtained with the last (largest) epoch budget.

    Raises:
        optuna.exceptions.TrialPruned: If Optuna decides to prune the trial
            after an intermediate report.
    """
    neu = trial.suggest_int(name="neu", low=500, high=700, step=10, log=False)
    lay = trial.suggest_int(name="lay", low=1, high=3, step=1, log=False)
    bat = trial.suggest_int(name="bat", low=4, high=8, step=4, log=False)
    lr = trial.suggest_float(name="lr", low=0.00001, high=0.0008, step=0.00001, log=False)
    enc_len = 24
    pred_len = 1
    drop = trial.suggest_float(name="dropout", low=0, high=0.2, step=0.2, log=False)

    ######### Load DATA #############
    cat_list = [
        "_hour_of_day",
        "_day_of_week",
        "_day_of_month",
        "_day_of_year",
        "_week_of_year",
        "_month_of_year",
        "_year",
    ]
    # One label encoder per calendar column, fitted on the training split only.
    cat_dict = {
        col: NaNLabelEncoder(add_nan=True).fit(train_data[col]) for col in cat_list
    }

    num_cols_list = ["dem_lag_168", "dem_lag_336", "inflow"]

    train_dataset = TimeSeriesDataSet(
        train_data,
        time_idx="time_idx",
        target=Target,
        categorical_encoders=cat_dict,
        group_ids=["group"],
        min_encoder_length=enc_len,
        max_encoder_length=enc_len,
        min_prediction_length=pred_len,
        max_prediction_length=pred_len,
        time_varying_unknown_reals=[Target],
        time_varying_known_reals=num_cols_list,
        time_varying_known_categoricals=cat_list,
        add_relative_time_idx=False,
        randomize_length=False,
        scalers={},
        target_normalizer=TorchNormalizer(method="identity", center=False, transformation=None),
    )
    # The validation set reuses the configuration and encoders of the training set.
    val_dataset = TimeSeriesDataSet.from_dataset(
        train_dataset, val_data, stop_randomization=True, predict=False
    )

    train_dataloader = train_dataset.to_dataloader(train=True, batch_size=bat)
    val_dataloader = val_dataset.to_dataloader(train=False, batch_size=bat)
    ######### Load DATA #############

    torch.set_num_threads(10)

    for num_ep in range(20, 35, 2):
        # A fresh Trainer and model per budget: each budget is a full
        # from-scratch training run of `num_ep` epochs.
        trainer = pl.Trainer(
            max_epochs=num_ep,
            # NOTE(review): -1 presumably selects every visible GPU - confirm
            # against the installed pytorch-lightning version.
            gpus=-1,
            auto_lr_find=False,
            gradient_clip_val=0.1,
            limit_train_batches=1.0,
            limit_val_batches=1.0,
            logger=True,
            callbacks=[lr_logger],
        )

        deepar = DeepAR.from_dataset(
            train_dataset,
            learning_rate=lr,
            hidden_size=neu,
            rnn_layers=lay,
            dropout=drop,
            loss=Loss,
            log_interval=20,
            log_val_interval=6,
            log_gradient_flow=False,
        )

        trainer.fit(
            deepar,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )

        val_rmse = trainer.callback_metrics["val_RMSE"].item()
        trial.report(val_rmse, num_ep)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_rmse

########## optuna results #####################
def main():
    """Run the Optuna study, print a summary and show the diagnostic plots.

    The study minimises the value returned by ``objective`` and stops
    starting new trials after the 6000-second timeout. When no trial reached
    the COMPLETE state (for example because every trial was pruned), the
    best-trial report and the plots are skipped, because ``study.best_trial``
    raises ``ValueError`` in that situation.
    """
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, timeout=6000)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    if not complete_trials:
        print("No trial completed; skipping best-trial report and plots.")
        return

    print("Best trial:")
    best = study.best_trial

    print("  Value: ", best.value)

    print("  Params: ")
    for key, value in best.params.items():
        print(f"    {key}: {value}")

    # Same four figures, shown in the same order as before.
    plot_functions = (
        optuna.visualization.plot_parallel_coordinate,
        optuna.visualization.plot_optimization_history,
        optuna.visualization.plot_slice,
        optuna.visualization.plot_param_importances,
    )
    for make_plot in plot_functions:
        make_plot(study).show()


if __name__ == "__main__":
    main()
########## optuna results #####################

The versions of packages in conda environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
absl-py                   1.4.0              pyhd8ed1ab_0    conda-forge
aiohttp                   3.8.4            py39h72bdee0_0    conda-forge
aiosignal                 1.3.1              pyhd8ed1ab_0    conda-forge
alembic                   1.10.2             pyhd8ed1ab_0    conda-forge
appdirs                   1.4.4              pyhd3eb1b0_0  
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
async-timeout             4.0.2              pyhd8ed1ab_0    conda-forge
attrs                     22.2.0             pyh71513ae_0    conda-forge
autopage                  0.5.1              pyhd8ed1ab_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl  
blinker                   1.5                pyhd8ed1ab_0    conda-forge
bottleneck                1.3.5            py39h7deecbd_0  
brotli                    1.0.9                h5eee18b_7  
brotli-bin                1.0.9                h5eee18b_7  
brotlipy                  0.7.0           py39hb9d737c_1005    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.18.1               h7f98852_0    conda-forge
ca-certificates           2023.01.10           h06a4308_0    anaconda
cachetools                5.3.0              pyhd8ed1ab_0    conda-forge
certifi                   2022.12.7        py39h06a4308_0    anaconda
cffi                      1.15.1           py39he91dace_3    conda-forge
charset-normalizer        2.1.1              pyhd8ed1ab_0    conda-forge
click                     8.1.3           unix_pyhd8ed1ab_2    conda-forge
cliff                     4.2.0              pyhd8ed1ab_0    conda-forge
cmaes                     0.9.1              pyhd8ed1ab_0    conda-forge
cmd2                      2.4.3            py39hf3d152e_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
colorlog                  6.7.0            py39hf3d152e_1    conda-forge
comm                      0.1.3              pyhd8ed1ab_0    conda-forge
contourpy                 1.0.5            py39hdb19cb5_0  
cryptography              39.0.0           py39hd598818_0    conda-forge
cuda-cudart               11.7.99                       0    nvidia
cuda-cupti                11.7.101                      0    nvidia
cuda-libraries            11.7.1                        0    nvidia
cuda-nvrtc                11.7.99                       0    nvidia
cuda-nvtx                 11.7.91                       0    nvidia
cuda-runtime              11.7.1                        0    nvidia
cudatoolkit               11.6.0              hecad31d_11    conda-forge
cycler                    0.11.0             pyhd3eb1b0_0  
dbus                      1.13.18              hb2f20db_0  
debugpy                   1.6.6            py39h227be39_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
executing                 1.2.0              pyhd8ed1ab_0    conda-forge
expat                     2.4.9                h6a678d5_0  
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.9.0            py39h06a4308_0  
fontconfig                2.14.1               h52c9d5c_1  
fonttools                 4.25.0             pyhd3eb1b0_0  
freetype                  2.10.4               h0708190_1    conda-forge
frozenlist                1.3.3            py39hb9d737c_0    conda-forge
fsspec                    2023.3.0           pyhd8ed1ab_1    conda-forge
future                    0.18.3             pyhd8ed1ab_0    conda-forge
giflib                    5.2.1                h5eee18b_3  
glib                      2.69.1               he621ea3_2  
gmp                       6.2.1                h58526e2_0    conda-forge
gmpy2                     2.1.2            py39heeb90bb_0  
gnutls                    3.6.13               h85f3911_1    conda-forge
google-auth               2.16.2             pyh1a96a4e_0    conda-forge
google-auth-oauthlib      0.4.6              pyhd8ed1ab_0    conda-forge
greenlet                  2.0.2            py39h227be39_0    conda-forge
grpcio                    1.38.1           py39hff7568b_0    conda-forge
gst-plugins-base          1.14.1               h6a678d5_1  
gstreamer                 1.14.1               h5eee18b_1  
icu                       58.2                 he6710b0_3  
idna                      3.4                pyhd8ed1ab_0    conda-forge
importlib-metadata        4.13.0             pyha770c72_0    conda-forge
importlib_metadata        4.13.0               hd8ed1ab_0    conda-forge
importlib_resources       5.2.0              pyhd3eb1b0_1  
intel-openmp              2021.4.0          h06a4308_3561  
ipykernel                 6.19.2           py39hb070fc8_0    anaconda
ipython                   8.11.0             pyh41d4057_0    conda-forge
jedi                      0.18.2             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2            py39h06a4308_0  
joblib                    1.1.1            py39h06a4308_0    anaconda
jpeg                      9e                   h0b41bf4_3    conda-forge
jsonschema                4.17.3           py39h06a4308_0  
jupyter_client            8.1.0              pyhd8ed1ab_0    conda-forge
jupyter_core              5.3.0            py39hf3d152e_0    conda-forge
kiwisolver                1.4.4            py39h6a678d5_0  
krb5                      1.19.4               h568e23c_0  
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.12                 hddcbb42_0    conda-forge
ld_impl_linux-64          2.38                 h1181459_1  
libbrotlicommon           1.0.9                h5eee18b_7  
libbrotlidec              1.0.9                h5eee18b_7  
libbrotlienc              1.0.9                h5eee18b_7  
libclang                  10.0.1          default_hb85057a_2  
libcublas                 11.10.3.66                    0    nvidia
libcufft                  10.7.2.124           h4fbf590_0    nvidia
libcufile                 1.6.0.25                      0    nvidia
libcurand                 10.3.2.56                     0    nvidia
libcusolver               11.4.0.1                      0    nvidia
libcusparse               11.7.4.91                     0    nvidia
libedit                   3.1.20221030         h5eee18b_0  
libevent                  2.1.12               h8f2d780_0  
libffi                    3.4.2                h6a678d5_6  
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgfortran-ng            7.5.0               ha8ba4b0_17  
libgfortran4              7.5.0               ha8ba4b0_17  
libiconv                  1.17                 h166bdaf_0    conda-forge
libllvm10                 10.0.1               hbcb73fb_5  
libnpp                    11.7.4.75                     0    nvidia
libnvjpeg                 11.8.0.2                      0    nvidia
libpng                    1.6.37               h21135ba_2    conda-forge
libpq                     12.9                 h16c4e8d_3  
libprotobuf               3.18.0               h780b84a_1    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              12.2.0              h46fd767_19    conda-forge
libtiff                   4.2.0                hf544144_3    conda-forge
libuuid                   1.41.5               h5eee18b_0  
libwebp                   1.2.4                h11a3e52_1  
libwebp-base              1.2.4                h5eee18b_1  
libxcb                    1.15                 h7f8727e_0  
libxkbcommon              1.0.1                hfa300c1_0  
libxml2                   2.9.14               h74e7548_0  
libxslt                   1.1.35               h4e12654_0  
lightning-utilities       0.8.0              pyhd8ed1ab_0    conda-forge
llvm-openmp               14.0.6               h9e868ea_0  
lz4-c                     1.9.3                h9c3ff4c_1    conda-forge
mako                      1.2.4              pyhd8ed1ab_0    conda-forge
markdown                  3.4.3              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.2            py39h72bdee0_0    conda-forge
matplotlib                3.4.3            py39h06a4308_0  
matplotlib-base           3.4.3            py39hbbc1b5f_0  
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mkl                       2021.4.0           h8d4b97c_729    conda-forge
mkl-service               2.4.0            py39h7e14d7c_0    conda-forge
mkl_fft                   1.3.1            py39h0c7bc48_1    conda-forge
mkl_random                1.2.2            py39hde0f152_0    conda-forge
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.2.1            py39h06a4308_0  
multidict                 6.0.4            py39h72bdee0_0    conda-forge
munkres                   1.1.4                      py_0  
nbformat                  5.7.0            py39h06a4308_0  
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.5.6              pyhd8ed1ab_0    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  2.8.4            py39h06a4308_1  
nspr                      4.33                 h295c915_0  
nss                       3.74                 h0370c37_0  
numexpr                   2.8.4            py39he184ba9_0  
numpy                     1.21.5           py39h6c91a56_3  
numpy-base                1.21.5           py39ha15fc14_3  
oauthlib                  3.2.2              pyhd8ed1ab_0    conda-forge
olefile                   0.46               pyh9f0ad1d_1    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.4.0                hb52868f_1    conda-forge
openssl                   1.1.1t               h7f8727e_0  
optuna                    3.0.5              pyhd8ed1ab_0    conda-forge
packaging                 23.0               pyhd8ed1ab_0    conda-forge
pandas                    1.3.4            py39h8c16a72_0  
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
patsy                     0.5.3            py39h06a4308_0  
pbr                       5.11.1             pyhd8ed1ab_0    conda-forge
pcre                      8.45                 h295c915_0  
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py39hf95b381_1    conda-forge
pip                       23.0.1           py39h06a4308_0  
platformdirs              3.1.1              pyhd8ed1ab_0    conda-forge
plotly                    5.11.0             pyhd8ed1ab_1    conda-forge
ply                       3.11             py39h06a4308_0  
pooch                     1.4.0              pyhd3eb1b0_0  
prettytable               3.6.0              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.38             pyha770c72_0    conda-forge
prompt_toolkit            3.0.38               hd8ed1ab_0    conda-forge
protobuf                  3.18.0           py39he80948d_0    conda-forge
psutil                    5.9.4            py39hb9d737c_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pyasn1                    0.4.8                      py_0    conda-forge
pyasn1-modules            0.2.7                      py_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pydeprecate               0.3.2              pyhd8ed1ab_0    conda-forge
pygments                  2.14.0             pyhd8ed1ab_0    conda-forge
pyjwt                     2.6.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 23.0.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.9            py39h06a4308_0  
pyperclip                 1.8.2              pyhd8ed1ab_2    conda-forge
pyqt                      5.15.7           py39h6a678d5_1  
pyqt5-sip                 12.11.0          py39h6a678d5_1  
pyrsistent                0.18.0           py39heee7806_0  
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.9.16               h7a1cb2a_2  
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-fastjsonschema     2.16.2           py39h06a4308_0  
python_abi                3.9                      2_cp39    conda-forge
pytorch                   2.0.0           py3.9_cuda11.7_cudnn8.5.0_0    pytorch
pytorch-cuda              11.7                 h778d358_3    pytorch
pytorch-lightning         1.7.2              pyhd8ed1ab_0    conda-forge
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2022.7           py39h06a4308_0  
pyu2f                     0.1.5              pyhd8ed1ab_0    conda-forge
pyyaml                    6.0              py39hb9d737c_5    conda-forge
pyzmq                     25.0.2           py39h0be026e_0    conda-forge
qt-main                   5.15.2               h327a75a_7  
qt-webengine              5.15.9               hd2b0992_4  
qtwebkit                  5.212                h4eab89a_4  
readline                  8.2                  h5eee18b_0  
requests                  2.28.2             pyhd8ed1ab_0    conda-forge
requests-oauthlib         1.3.1              pyhd8ed1ab_0    conda-forge
rsa                       4.9                pyhd8ed1ab_0    conda-forge
scikit-learn              1.0.2            py39h51133e4_1  
scipy                     1.7.1            py39h292c36d_2  
setuptools                59.5.0           py39hf3d152e_0    conda-forge
sip                       6.6.2            py39h6a678d5_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlalchemy                2.0.7            py39h72bdee0_0    conda-forge
sqlite                    3.41.1               h5eee18b_0  
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
statsmodels               0.13.2           py39h7f8727e_0  
stevedore                 5.0.0              pyhd8ed1ab_0    conda-forge
sympy                     1.11.1           py39h06a4308_0  
tbb                       2021.7.0             h924138e_0    conda-forge
tenacity                  8.2.2              pyhd8ed1ab_0    conda-forge
tensorboard               2.11.2             pyhd8ed1ab_0    conda-forge
tensorboard-data-server   0.6.1            py39hd97740a_4    conda-forge
tensorboard-plugin-wit    1.8.1              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.2.0              pyh0d69192_0    anaconda
tk                        8.6.12               h1ccaba5_0  
toml                      0.10.2             pyhd3eb1b0_0  
torchaudio                2.0.0                py39_cu117    pytorch
torchmetrics              0.9.3              pyhd8ed1ab_0    conda-forge
torchtriton               2.0.0                      py39    pytorch
torchvision               0.15.0               py39_cu117    pytorch
tornado                   6.2              py39hb9d737c_1    conda-forge
tqdm                      4.65.0             pyhd8ed1ab_1    conda-forge
traitlets                 5.9.0              pyhd8ed1ab_0    conda-forge
typing-extensions         4.5.0                hd8ed1ab_0    conda-forge
typing_extensions         4.5.0              pyha770c72_0    conda-forge
tzdata                    2022g                h04d1e81_0  
urllib3                   1.26.15            pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
werkzeug                  2.2.3              pyhd8ed1ab_0    conda-forge
wheel                     0.38.4           py39h06a4308_0  
xz                        5.2.10               h5eee18b_1  
yaml                      0.2.5                h7f98852_2    conda-forge
yarl                      1.8.2            py39hb9d737c_0    conda-forge
zeromq                    4.3.4                h9c3ff4c_1    conda-forge
zipp                      3.15.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h5eee18b_0  
zstd                      1.5.2                ha4553b6_0  


nvidia-smi during training:

Wed Mar 29 11:46:38 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   37C    P2    76W / 450W |   1258MiB / 24564MiB |     33%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2030      G   /usr/lib/xorg/Xorg                407MiB |
|    0   N/A  N/A      2163      G   /usr/bin/gnome-shell               49MiB |
|    0   N/A  N/A     22794      G   ...870274336488489090,131072      105MiB |
|    0   N/A  N/A     26677      G   ...RendererForSitePerProcess       35MiB |
|    0   N/A  N/A     29747      G   /usr/bin/totem                     15MiB |
|    0   N/A  N/A     32739      C   ...nvs/deepar-gpu/bin/python      640MiB |
+-----------------------------------------------------------------------------+

I solved the problem. The solution is here: DeepAR Training gets stuck at some random epoch · Issue #1281 · jdb78/pytorch-forecasting · GitHub