Unable to reinitialize custom model after running lr_find

I am trying to create a custom model with the pytorch-forecasting library, but after running lr_find on the model I am seeing the error below.

“RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment”

I am not able to figure out why running lr_find would prevent the model from being reinitialized using its own .from_dataset() method, given that the same initialization happens without any issues beforehand. If I don't use lr_find and simply proceed to training, I face no issues training the model.

Below is the code:

import torch
from torch import nn


class FullyConnectedModule(nn.Module):
    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int):
        super().__init__()

        # input layer
        module_list = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        # hidden layers
        for _ in range(n_hidden_layers):
            module_list.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])
        # output layer
        module_list.append(nn.Linear(hidden_size, output_size))

        self.sequential = nn.Sequential(*module_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x of shape: batch_size x n_timesteps_in
        # output of shape batch_size x n_timesteps_out
        return self.sequential(x)
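A quick sanity check (not part of the original code; the sizes are placeholders) shows the module mapping a batch of encoder timesteps to a batch of prediction timesteps:

# hypothetical sizes for illustration only
net = FullyConnectedModule(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2)
x = torch.randn(4, 5)   # batch_size x n_timesteps_in
print(net(x).shape)     # torch.Size([4, 2]) -> batch_size x n_timesteps_out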

from typing import Dict

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models import BaseModel


class FullyConnectedModel(BaseModel):
    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, **kwargs):
        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)
        self.network = FullyConnectedModule(
            input_size=self.hparams.input_size,
            output_size=self.hparams.output_size,
            hidden_size=self.hparams.hidden_size,
            n_hidden_layers=self.hparams.n_hidden_layers,
        )
        
        
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # x is a batch generated based on the TimeSeriesDataset
        network_input = x["encoder_cont"].squeeze(-1)
        prediction = self.network(network_input).unsqueeze(-1)

        # rescale predictions into target space
        prediction = self.transform_output(prediction, target_scale=x["target_scale"])

        # We need to return a dictionary that at least contains the prediction.
        # The parameter can be directly forwarded from the input.
        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.
        return self.to_network_output(prediction=prediction)

    @classmethod
    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
        new_kwargs = {
            "output_size": dataset.max_prediction_length,
            "input_size": dataset.max_encoder_length,
        }
        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset
        # example for dataset validation
        assert dataset.max_prediction_length == dataset.min_prediction_length, "Decoder only supports a fixed length"
        assert dataset.min_encoder_length == dataset.max_encoder_length, "Encoder only supports a fixed length"
        assert (
            len(dataset.time_varying_known_categoricals) == 0
            and len(dataset.time_varying_known_reals) == 0
            and len(dataset.time_varying_unknown_categoricals) == 0
            and len(dataset.static_categoricals) == 0
            and len(dataset.static_reals) == 0
            and len(dataset.time_varying_unknown_reals) == 1
            and dataset.time_varying_unknown_reals[0] == dataset.target
        ), "Only covariate should be the target in 'time_varying_unknown_reals'"

        return super().from_dataset(dataset, **new_kwargs)

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary


model = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
print(ModelSummary(model, max_depth=-1))
model.hparams

pl.seed_everything(42)
trainer = pl.Trainer(accelerator="auto", gradient_clip_val=0.01, auto_select_gpus=True, gpus=-1,)

from pytorch_lightning.tuner.tuning import Tuner

res = Tuner(trainer).lr_find(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
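The suggested rate would then be applied to the model before fitting (mirroring the commented-out line visible in the traceback below):

# set the tuned learning rate on the model's hyperparameters
model.hparams.learning_rate = res.suggestion()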

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="auto",
    enable_model_summary=True,
    gradient_clip_val=0.01,
    callbacks=[early_stop_callback],
    limit_train_batches=150,
    auto_select_gpus=True,
    gpus=-1,
)

model = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)

The code ends up breaking on the last line with the error mentioned above, which is strange because it is the exact same call that initialized the model successfully before running lr_find().

The stacktrace looks as follows:

RuntimeError                              Traceback (most recent call last)
Cell In [37], line 25
      2 trainer = pl.Trainer(
      3     max_epochs=3,
      4     accelerator="auto",
   (...)
     10     gpus=-1,
     11 )
     14 # model = FullyConnectedModel.from_dataset(
     15 #     training,
     16 #     learning_rate=1e-3,
   (...)
     22 # )
     23 # del model
---> 25 model_new = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
     26 # model = FullyConnectedClassificationModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
     27 
     28 
     29 # model.hparams.learning_rate = res.suggestion()
     31 trainer.fit(
     32     model_new,
     33     train_dataloaders=train_dataloader,
     34     val_dataloaders=val_dataloader,
     35 )

Cell In [27], line 53
     42 assert dataset.min_encoder_length == dataset.max_encoder_length, "Encoder only supports a fixed length"
     43 assert (
     44     len(dataset.time_varying_known_categoricals) == 0
     45     and len(dataset.time_varying_known_reals) == 0
   (...)
     50     and dataset.time_varying_unknown_reals[0] == dataset.target
     51 ), "Only covariate should be the target in 'time_varying_unknown_reals'"
---> 53 return super().from_dataset(dataset, **new_kwargs)

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_forecasting\models\base_model.py:996, in BaseModel.from_dataset(cls, dataset, **kwargs)
    994 if "output_transformer" not in kwargs:
    995     kwargs["output_transformer"] = dataset.target_normalizer
--> 996 net = cls(**kwargs)
    997 net.dataset_parameters = dataset.get_parameters()
    998 if dataset.multi_target:

Cell In [27], line 11
      9 self.save_hyperparameters()
     10 # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
---> 11 super().__init__(**kwargs)
     12 self.network = FullyConnectedModule(
     13     input_size=self.hparams.input_size,
     14     output_size=self.hparams.output_size,
     15     hidden_size=self.hparams.hidden_size,
     16     n_hidden_layers=self.hparams.n_hidden_layers,
     17 )

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_forecasting\models\base_model.py:261, in BaseModel.__init__(self, log_interval, log_val_interval, learning_rate, log_gradient_flow, loss, logging_metrics, reduce_on_plateau_patience, reduce_on_plateau_reduction, reduce_on_plateau_min_lr, weight_decay, optimizer_params, monotone_constaints, output_transformer, optimizer)
    259 frame = inspect.currentframe()
    260 init_args = get_init_args(frame)
--> 261 self.save_hyperparameters(
    262     {name: val for name, val in init_args.items() if name not in self.hparams and name not in ["self"]}
    263 )
    265 # update log interval if not defined
    266 if self.hparams.log_val_interval is None:

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_lightning\core\mixins\hparams_mixin.py:105, in HyperparametersMixin.save_hyperparameters(self, ignore, frame, logger, *args)
    103 if not frame:
    104     frame = inspect.currentframe().f_back
--> 105 save_hyperparameters(self, *args, ignore=ignore, frame=frame)

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_lightning\utilities\parsing.py:251, in save_hyperparameters(obj, ignore, frame, *args)
    249     obj._set_hparams(hp)
    250 # make deep copy so  there is not other runtime changes reflected
--> 251 obj._hparams_initial = copy.deepcopy(obj._hparams)

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:172, in deepcopy(x, memo, _nil)
    170                 y = x
    171             else:
--> 172                 y = _reconstruct(x, memo, *rv)
    174 # If is its own copy, don't memoize.
    175 if y is not x:

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:296, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    294     for key, value in dictiter:
    295         key = deepcopy(key, memo)
--> 296         value = deepcopy(value, memo)
    297         y[key] = value
    298 else:

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:172, in deepcopy(x, memo, _nil)
    170                 y = x
    171             else:
--> 172                 y = _reconstruct(x, memo, *rv)
    174 # If is its own copy, don't memoize.
    175 if y is not x:

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:270, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    268 if state is not None:
    269     if deep:
--> 270         state = deepcopy(state, memo)
    271     if hasattr(y, '__setstate__'):
    272         y.__setstate__(state)

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:146, in deepcopy(x, memo, _nil)
    144 copier = _deepcopy_dispatch.get(cls)
    145 if copier is not None:
--> 146     y = copier(x, memo)
    147 else:
    148     if issubclass(cls, type):

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:230, in _deepcopy_dict(x, memo, deepcopy)
    228 memo[id(x)] = y
    229 for key, value in x.items():
--> 230     y[deepcopy(key, memo)] = deepcopy(value, memo)
    231 return y

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:153, in deepcopy(x, memo, _nil)
    151 copier = getattr(x, "__deepcopy__", None)
    152 if copier is not None:
--> 153     y = copier(memo)
    154 else:
    155     reductor = dispatch_table.get(cls)

File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\torch\_tensor.py:102, in Tensor.__deepcopy__(self, memo)
    100     return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
    101 if not self.is_leaf:
--> 102     raise RuntimeError(
    103         "Only Tensors created explicitly by the user "
    104         "(graph leaves) support the deepcopy protocol at the moment"
    105     )
    106 if id(self) in memo:
    107     return memo[id(self)]

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
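For context, the error itself is easy to reproduce in plain PyTorch, since deepcopy is only supported for leaf tensors:

import copy
import torch

leaf = torch.ones(2, requires_grad=True)
non_leaf = leaf * 2      # produced by an operation, so not a graph leaf
copy.deepcopy(leaf)      # works
copy.deepcopy(non_leaf)  # RuntimeError: Only Tensors created explicitly by the user ...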

I am using the code given in the official documentation for the pytorch-forecasting library, which can be found here: How to use custom data and implement custom models and metrics — pytorch-forecasting documentation

Hi all,

I compared the custom implementation with the NBeats model implementation in pytorch-forecasting and tried to mimic its init function in terms of the parameters other than the model-specific ones. I was able to resolve the issue by specifying the loss (and logging metrics) in the init method of the custom model as follows:

# additional imports needed for the loss / logging metrics below
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric

def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int,
        # learning_rate: float = 1e-2,
        # log_interval: int = -1,
        # log_gradient_flow: bool = False,
        # log_val_interval: int = None,
        # weight_decay: float = 1e-3,
        loss: MultiHorizonMetric = None,
        # reduce_on_plateau_patience: int = 1000,
        # backcast_loss_ratio: float = 0.0,
        logging_metrics: nn.ModuleList = None, 
        **kwargs):
        if logging_metrics is None:
            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
        if loss is None:
            loss = MASE()
        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)
        self.network = FullyConnectedModule(
            input_size=self.hparams.input_size,
            output_size=self.hparams.output_size,
            hidden_size=self.hparams.hidden_size,
            n_hidden_layers=self.hparams.n_hidden_layers,
        )
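My understanding (not fully verified) is that the default loss and logging-metric objects coming from BaseModel.__init__ are shared instances; after the lr_find run they hold non-leaf tensors in their state, so the deepcopy inside save_hyperparameters fails when the next model is constructed. Creating fresh MASE() / metric instances inside __init__ avoids reusing that state, and re-initializing after lr_find works again:

model = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)  # no longer raises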