I am trying to create a custom model in PyTorch, but after running lr_find on the model I see the error below:
“RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment”
I am not able to figure out why running lr_find would prevent the model from being reinitialized via its own .from_dataset() method, given that the same initialization works without any issue before lr_find is run. If I skip lr_find and proceed straight to training, the model trains without problems.
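For reference, the error itself is easy to reproduce in isolation: torch refuses to deepcopy any tensor that was produced by an autograd operation rather than created directly by the user (a non-leaf). A minimal sketch:

import copy

import torch

leaf = torch.ones(3, requires_grad=True)  # created directly by the user -> graph leaf
non_leaf = leaf * 2                       # produced by an autograd op -> not a leaf

copy.deepcopy(leaf)      # works
copy.deepcopy(non_leaf)  # raises the RuntimeError quoted above

Judging from the traceback at the end of this post, the deepcopy happens when save_hyperparameters() copies the hparams of the newly constructed model, so one of the values forwarded by from_dataset() apparently ends up holding a non-leaf tensor once lr_find has run.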
Below is the code:
import torch
from torch import nn


class FullyConnectedModule(nn.Module):
    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int):
        super().__init__()
        # input layer
        module_list = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        # hidden layers
        for _ in range(n_hidden_layers):
            module_list.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])
        # output layer
        module_list.append(nn.Linear(hidden_size, output_size))
        self.sequential = nn.Sequential(*module_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x of shape: batch_size x n_timesteps_in
        # output of shape: batch_size x n_timesteps_out
        return self.sequential(x)
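As in the tutorial, a quick sanity check of this plain nn.Module on random data behaves as expected (sizes here are arbitrary):

net = FullyConnectedModule(input_size=5, output_size=2, hidden_size=10, n_hidden_layers=2)
x = torch.rand(20, 5)  # batch_size x n_timesteps_in
print(net(x).shape)    # torch.Size([20, 2])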
from typing import Dict

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models import BaseModel


class FullyConnectedModel(BaseModel):
    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_hidden_layers: int, **kwargs):
        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
        self.save_hyperparameters()
        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
        super().__init__(**kwargs)
        self.network = FullyConnectedModule(
            input_size=self.hparams.input_size,
            output_size=self.hparams.output_size,
            hidden_size=self.hparams.hidden_size,
            n_hidden_layers=self.hparams.n_hidden_layers,
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # x is a batch generated based on the TimeSeriesDataSet
        network_input = x["encoder_cont"].squeeze(-1)
        prediction = self.network(network_input).unsqueeze(-1)
        # rescale predictions into target space
        prediction = self.transform_output(prediction, target_scale=x["target_scale"])
        # We need to return a dictionary that at least contains the prediction.
        # The parameter can be directly forwarded from the input.
        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.
        return self.to_network_output(prediction=prediction)

    @classmethod
    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
        new_kwargs = {
            "output_size": dataset.max_prediction_length,
            "input_size": dataset.max_encoder_length,
        }
        new_kwargs.update(kwargs)  # use to pass real hyperparameters and override defaults set by dataset
        # example for dataset validation
        assert dataset.max_prediction_length == dataset.min_prediction_length, "Decoder only supports a fixed length"
        assert dataset.min_encoder_length == dataset.max_encoder_length, "Encoder only supports a fixed length"
        assert (
            len(dataset.time_varying_known_categoricals) == 0
            and len(dataset.time_varying_known_reals) == 0
            and len(dataset.time_varying_unknown_categoricals) == 0
            and len(dataset.static_categoricals) == 0
            and len(dataset.static_reals) == 0
            and len(dataset.time_varying_unknown_reals) == 1
            and dataset.time_varying_unknown_reals[0] == dataset.target
        ), "Only covariate should be the target in 'time_varying_unknown_reals'"
        return super().from_dataset(dataset, **new_kwargs)
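For completeness: training, train_dataloader and val_dataloader are defined earlier in my notebook, following the synthetic-data example from the same tutorial, roughly like this:

import numpy as np
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet

# sketch of the dataset setup; the actual columns and lengths in my notebook may differ
test_data = pd.DataFrame(
    dict(
        value=np.random.rand(30) - 0.5,
        group=np.repeat(np.arange(3), 10),
        time_idx=np.tile(np.arange(10), 3),
    )
)
training = TimeSeriesDataSet(
    test_data,
    group_ids=["group"],
    target="value",
    time_idx="time_idx",
    max_encoder_length=5,
    max_prediction_length=2,
    time_varying_unknown_reals=["value"],
)
train_dataloader = training.to_dataloader(batch_size=4)
val_dataloader = training.to_dataloader(batch_size=4, train=False)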
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary

model = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
print(ModelSummary(model, max_depth=-1))
model.hparams  # inspect the stored hyperparameters

pl.seed_everything(42)
trainer = pl.Trainer(accelerator="auto", gradient_clip_val=0.01, auto_select_gpus=True, gpus=-1)

from pytorch_lightning.tuner.tuning import Tuner

res = Tuner(trainer).lr_find(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
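Normally I would then feed the suggestion back into the model before retraining (this is the line that is commented out in the traceback below):

model.hparams.learning_rate = res.suggestion()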
from pytorch_lightning.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="auto",
    enable_model_summary=True,
    gradient_clip_val=0.01,
    callbacks=[early_stop_callback],
    limit_train_batches=150,
    auto_select_gpus=True,
    gpus=-1,
)
model = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
The code breaks on the last line with the above-mentioned error, which is strange because this is exactly the same call that successfully initialized the model before lr_find() was run.
The stack trace is as follows:
RuntimeError Traceback (most recent call last)
Cell In [37], line 25
2 trainer = pl.Trainer(
3 max_epochs=3,
4 accelerator="auto",
(...)
10 gpus=-1,
11 )
14 # model = FullyConnectedModel.from_dataset(
15 # training,
16 # learning_rate=1e-3,
(...)
22 # )
23 # del model
---> 25 model_new = FullyConnectedModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
26 # model = FullyConnectedClassificationModel.from_dataset(training, hidden_size=10, n_hidden_layers=2)
27
28
29 # model.hparams.learning_rate = res.suggestion()
31 trainer.fit(
32 model_new,
33 train_dataloaders=train_dataloader,
34 val_dataloaders=val_dataloader,
35 )
Cell In [27], line 53
42 assert dataset.min_encoder_length == dataset.max_encoder_length, "Encoder only supports a fixed length"
43 assert (
44 len(dataset.time_varying_known_categoricals) == 0
45 and len(dataset.time_varying_known_reals) == 0
(...)
50 and dataset.time_varying_unknown_reals[0] == dataset.target
51 ), "Only covariate should be the target in 'time_varying_unknown_reals'"
---> 53 return super().from_dataset(dataset, **new_kwargs)
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_forecasting\models\base_model.py:996, in BaseModel.from_dataset(cls, dataset, **kwargs)
994 if "output_transformer" not in kwargs:
995 kwargs["output_transformer"] = dataset.target_normalizer
--> 996 net = cls(**kwargs)
997 net.dataset_parameters = dataset.get_parameters()
998 if dataset.multi_target:
Cell In [27], line 11
9 self.save_hyperparameters()
10 # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
---> 11 super().__init__(**kwargs)
12 self.network = FullyConnectedModule(
13 input_size=self.hparams.input_size,
14 output_size=self.hparams.output_size,
15 hidden_size=self.hparams.hidden_size,
16 n_hidden_layers=self.hparams.n_hidden_layers,
17 )
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_forecasting\models\base_model.py:261, in BaseModel.__init__(self, log_interval, log_val_interval, learning_rate, log_gradient_flow, loss, logging_metrics, reduce_on_plateau_patience, reduce_on_plateau_reduction, reduce_on_plateau_min_lr, weight_decay, optimizer_params, monotone_constaints, output_transformer, optimizer)
259 frame = inspect.currentframe()
260 init_args = get_init_args(frame)
--> 261 self.save_hyperparameters(
262 {name: val for name, val in init_args.items() if name not in self.hparams and name not in ["self"]}
263 )
265 # update log interval if not defined
266 if self.hparams.log_val_interval is None:
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_lightning\core\mixins\hparams_mixin.py:105, in HyperparametersMixin.save_hyperparameters(self, ignore, frame, logger, *args)
103 if not frame:
104 frame = inspect.currentframe().f_back
--> 105 save_hyperparameters(self, *args, ignore=ignore, frame=frame)
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\pytorch_lightning\utilities\parsing.py:251, in save_hyperparameters(obj, ignore, frame, *args)
249 obj._set_hparams(hp)
250 # make deep copy so there is not other runtime changes reflected
--> 251 obj._hparams_initial = copy.deepcopy(obj._hparams)
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:172, in deepcopy(x, memo, _nil)
170 y = x
171 else:
--> 172 y = _reconstruct(x, memo, *rv)
174 # If is its own copy, don't memoize.
175 if y is not x:
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:296, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
294 for key, value in dictiter:
295 key = deepcopy(key, memo)
--> 296 value = deepcopy(value, memo)
297 y[key] = value
298 else:
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:172, in deepcopy(x, memo, _nil)
170 y = x
171 else:
--> 172 y = _reconstruct(x, memo, *rv)
174 # If is its own copy, don't memoize.
175 if y is not x:
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:270, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
268 if state is not None:
269 if deep:
--> 270 state = deepcopy(state, memo)
271 if hasattr(y, '__setstate__'):
272 y.__setstate__(state)
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:146, in deepcopy(x, memo, _nil)
144 copier = _deepcopy_dispatch.get(cls)
145 if copier is not None:
--> 146 y = copier(x, memo)
147 else:
148 if issubclass(cls, type):
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:230, in _deepcopy_dict(x, memo, deepcopy)
228 memo[id(x)] = y
229 for key, value in x.items():
--> 230 y[deepcopy(key, memo)] = deepcopy(value, memo)
231 return y
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\copy.py:153, in deepcopy(x, memo, _nil)
151 copier = getattr(x, "__deepcopy__", None)
152 if copier is not None:
--> 153 y = copier(memo)
154 else:
155 reductor = dispatch_table.get(cls)
File c:\Users\pulah\anaconda3\envs\modern_ts_new\lib\site-packages\torch\_tensor.py:102, in Tensor.__deepcopy__(self, memo)
100 return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
101 if not self.is_leaf:
--> 102 raise RuntimeError(
103 "Only Tensors created explicitly by the user "
104 "(graph leaves) support the deepcopy protocol at the moment"
105 )
106 if id(self) in memo:
107 return memo[id(self)]
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
I am using the code given in the official documentation for the pytorch-forecasting library, which can be found here: How to use custom data and implement custom models and metrics — pytorch-forecasting documentation