I am training an ensemble of two deep learning models. The workflow runs fine and saves a `.pth` file when trained on an M1 Mac, but on a GPU machine I get the error below. I tried switching torch to version 1.12.0 and got the same error. The full traceback and the relevant code snippets are included below. Please assist. Thank you.
# save model function outside Trainer Class
def save_model(model, model_filename, path='.'):
    """Serialize ``model.state_dict()`` to ``<path>/<model_filename>`` with ``torch.save``.

    Args:
        model: Object whose ``state_dict()`` is written (presumably a
            ``torch.nn.Module`` -- confirm at call sites).
        model_filename: Target file name. ``'.pth'`` is appended when the
            name does not already end with that suffix.
        path: Directory that will contain the file (default: ``'.'``).

    NOTE(review): ``nn.Module.state_dict()`` recurses into child modules,
    passing ``destination``, ``prefix`` and ``keep_vars`` as keyword
    arguments. The ``TypeError: state_dict() got an unexpected keyword
    argument 'destination'`` in the training traceback is raised inside that
    recursion. Presumably some child module overrides ``state_dict`` with a
    signature lacking these kwargs -- check the ensemble's submodules.
    """
    suffix = '.pth'
    if model_filename.endswith(suffix):
        file_name = model_filename
    else:
        file_name = model_filename + suffix

    target_file = os.path.join(path, file_name)
    # Project helper called with the full file path before writing;
    # presumably creates missing parent directories -- confirm in ``io``.
    io.create_path(target_file)

    weights = model.state_dict()
    torch.save(weights, target_file)
# save_model method inside Trainer class
def save_model(self, model_filename=None, *, models_dir='models'):
    """Save ``self.model``'s weights into ``<self.path>/<models_dir>``.

    Delegates to the module-level ``save_model`` helper (defined outside the
    Trainer class), which appends ``'.pth'`` when missing and writes
    ``self.model.state_dict()`` with ``torch.save``.

    Args:
        model_filename: File name to write. When omitted (``None``), falls
            back to ``self.model_filename``.
        models_dir: Sub-directory of ``self.path`` to write into.

    Raises:
        ValueError: If both ``model_filename`` and ``self.model_filename``
            are ``None``.
    """
    if model_filename is None:
        if self.model_filename is None:
            raise ValueError('Param "model_filename" is None')
        model_filename = self.model_filename

    target_dir = os.path.join(self.path, models_dir)
    save_model(self.model, model_filename, target_dir)
# last part of training method where it saves the model
# update ReduceLROnPlateau scheduler (if available)
if (valid_loss is not None and scheduler is not None
and isinstance(scheduler, lr_scheduler.ReduceLROnPlateau)):
scheduler.step(valid_loss)
if self.primary_metric in scores:
# save model checkpoint
is_better = scores[self.primary_metric] > best_score
if is_better:
best_score = scores[self.primary_metric]
best_state_dict = deepcopy(self.model.state_dict())
# best_state_dict = self.model.state_dict()
if self.model_filename is not None:
print(f'Epoch {epoch+1} - Save Checkpoint with Best Recall Score: {best_score:.6f}')
self.save_model(self.model_filename, models_dir='models')
if best_state_dict is not None:
self.model.load_state_dict(best_state_dict)
Traceback (most recent call last):
File "/notebooks/ensamble_model/ensamble_models/fine_tune.py", line 129, in <module>
trainer.train(no_epochs=config.no_epochs, lr=config.learning_rate)
File "/notebooks/ensamble_model/ensamble_models/src/core/training.py", line 259, in train
best_state_dict = self.model.state_dict()
File "/notebooks/.env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1381, in state_dict
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
TypeError: state_dict() got an unexpected keyword argument 'destination'