I have a model (a .pth file) trained on PyTorch version 2.x, but for the purpose of quantization I need to convert the same model to PyTorch version 1.x. Is there a way to do this?
Could you describe what your use case is? The common use case is to store the state_dict which contains the parameters and buffers and should generally be compatible between various PyTorch releases.
Use case: deploying the model on hardware. The hardware requires a quantized model, and the model to be quantized needs a .pth file in the PyTorch 1.x format. So even if I train the model in a PyTorch 2.x environment, it needs to be downgraded to the 1.x format — can this be done or not?
Here is the code I am currently using to downgrade the model to the 1.x format.
import torch
import torch.nn as nn
import copy
import os
import sys
from config import cfg
from lib.models.movenet_mobilenetv2 import MoveNet
def _load_source_state_dict(path):
    """Load a checkpoint from *path* and return its state_dict.

    Handles both a raw state_dict and a pickled full ``nn.Module`` (in
    which case ``.state_dict()`` is extracted).  Returns ``None`` on
    failure instead of raising, reporting the error on stdout.
    """
    try:
        loaded = torch.load(path, map_location='cpu')
        if not isinstance(loaded, dict):
            print(" Loaded object is a full model, extracting state_dict")
            loaded = loaded.state_dict()
        print(" Original model weights loaded successfully")
        return loaded
    except Exception as e:
        print(f" Error loading original model: {e}")
        return None


def _scan_for_incompatible_keys(state_dict):
    """Return state_dict keys whose names hint at ops newer than PyTorch 1.8.

    This is a heuristic name scan only — it cannot detect incompatible ops
    that leave no trace in parameter names.
    """
    # Known compatibility issues that may need transformation.
    incompatible_ops = {
        # 'new_op_name': 'old_op_equivalent'
        'silu': 'sigmoid',  # SiLU/Swish activation was added in later versions
        'scaled_dot_product_attention': None,  # Added in PyTorch 2.0
        # Add more as needed
    }
    problematic_keys = []
    for key in state_dict:
        for incompatible_op in incompatible_ops:
            if incompatible_op in key:
                problematic_keys.append(key)
                print(f" Found potentially incompatible operation in key: {key}")
    return problematic_keys


def _report_key_mismatch(model, state_dict):
    """Print the missing/unexpected key diff between *model* and *state_dict*.

    Used for debugging when ``load_state_dict`` fails.
    """
    target_keys = set(model.state_dict().keys())
    source_keys = set(state_dict.keys())
    missing_keys = target_keys - source_keys
    unexpected_keys = source_keys - target_keys
    if missing_keys:
        print(" Missing keys in source state_dict:")
        for key in sorted(missing_keys):
            print(f" {key}")
    if unexpected_keys:
        print(" Unexpected keys in source state_dict:")
        for key in sorted(unexpected_keys):
            print(f" {key}")


def downgrade_model_to_pytorch_1_8():
    """Re-save a PyTorch 2.x checkpoint in the legacy serialization format.

    The tensor data in a state_dict is version-agnostic; what breaks old
    runtimes is the zipfile container introduced in PyTorch 1.6.  Saving
    with ``_use_new_zipfile_serialization=False`` produces the legacy
    format that PyTorch 1.x can read.

    NOTE(review): this cannot remove *operations* (e.g. SiLU, SDPA) from
    the architecture itself — if ``MoveNet`` uses ops absent from the
    target 1.x release, the model code must be adapted separately.

    Returns None; progress and errors are reported on stdout.
    """
    print("Starting PyTorch model downgrade process...")

    # Paths (Windows raw strings; adjust for your environment).
    original_model_path = r"C:\Users\Z9IMKS3\Downloads\e120_valacc0_89439_tb_gl_ws.pth"
    downgraded_model_path = r"C:\Users\Z9IMKS3\Downloads\e120_valacc0_89439_tb_gl_ws_352_352.pth"
    print(f"Original model path: {original_model_path}")
    print(f"Target downgraded model path: {downgraded_model_path}")

    # Step 1: Load the original model weights.
    print("\nStep 1: Loading original model weights...")
    state_dict = _load_source_state_dict(original_model_path)
    if state_dict is None:
        return

    # Step 2: Heuristic scan for ops added after PyTorch 1.8.
    print("\nStep 2: Checking for PyTorch 1.8.0 incompatibilities...")
    problematic_keys = _scan_for_incompatible_keys(state_dict)
    if problematic_keys:
        print(f" Warning: Found {len(problematic_keys)} potentially incompatible keys")
    else:
        print(" No known incompatibilities detected in state_dict keys")

    # Step 3: Build a fresh model with the same architecture/hyperparameters.
    print("\nStep 3: Creating fresh PyTorch 1.8.0 compatible model...")
    try:
        pytorch1_8_model = MoveNet(num_keypoints=cfg['num_keypoints'], mode='train', width_mult=1.75)
        print(" PyTorch 1.8.0 compatible model created successfully")
    except Exception as e:
        print(f" Error creating PyTorch 1.8.0 compatible model: {e}")
        return

    # Step 4: Transfer parameters into the fresh model.
    print("\nStep 4: Transferring parameters to PyTorch 1.8.0 model...")
    try:
        pytorch1_8_model.load_state_dict(state_dict)
        print(" Parameters transferred successfully")
    except Exception as e:
        print(f" Error transferring parameters: {e}")
        _report_key_mismatch(pytorch1_8_model, state_dict)
        return

    # Step 5: Save in the pre-1.6 (legacy, non-zipfile) serialization format.
    # pickle_protocol=2 is passed explicitly for maximum backward
    # compatibility of the pickle stream itself.
    print("\nStep 5: Saving downgraded model...")
    try:
        torch.save(
            pytorch1_8_model.state_dict(),
            downgraded_model_path,
            pickle_protocol=2,
            _use_new_zipfile_serialization=False,
        )
        print(" Model saved with PyTorch 1.8.0 compatible format")
    except Exception as e:
        print(f" Error saving downgraded model: {e}")
        return

    # Step 6: Round-trip verification — reload, rebuild, and run inference.
    print("\nStep 6: Verifying downgraded model...")
    try:
        verification_dict = torch.load(downgraded_model_path, map_location='cpu')
        print(" Downgraded model loads successfully")
        test_model = MoveNet(num_keypoints=cfg['num_keypoints'], mode='test', width_mult=1.75)
        test_model.load_state_dict(verification_dict)
        test_model.eval()
        # Quick smoke-test inference at the target input resolution.
        dummy_input = torch.randn(1, 3, 352, 352)
        with torch.no_grad():
            try:
                test_model(dummy_input)
                print(" Model passes basic inference test")
            except Exception as e:
                print(f" Warning: Model failed inference test: {e}")
        print("\nDowngrade completed successfully!")
        print(f"Downgraded model saved to: {downgraded_model_path}")
    except Exception as e:
        print(f" Error verifying downgraded model: {e}")
        return
# Script entry point. The original had two fatal typos: single-underscore
# `_name_`/`_main_` (NameError) and curly quotes (SyntaxError).
if __name__ == "__main__":
    downgrade_model_to_pytorch_1_8()
There has been no reply so far to the above problem — @ptrblck, could you take a look?
No, I don’t think PyTorch provides tools to store data using an old format since there are no promises for forward compatibility (i.e. running a new model using an old PyTorch release).