Hello!
I was thinking that it could be useful to save metadata about a model when you export it via TorchScript. One example of such metadata is the git hash of the code the model was trained with.
Some example code:
import torch
import torch.nn as nn


class Mymodel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(10, 10, 1, 1, 0, bias=True)
        self.accuracy = 0

    # Exported so the method is compiled and survives save/load,
    # even though forward() never calls it.
    @torch.jit.export
    def model_info(self):
        return {
            "git_hash": "*git-hash-info*",
            "created": "*timestamp*",
            "accuracy": f"{self.accuracy}",
            "pytorch_version": "*version*",
        }

    def forward(self, x):
        x = self.conv(x)
        return x


model = Mymodel()
model.accuracy = 10  # set before scripting so the value is captured

scripted_model = torch.jit.script(model)
model_path = 'testmodel.pt'
scripted_model.save(model_path)

# Reload the model and read the embedded metadata back.
loaded_model = torch.jit.load(model_path)
print(loaded_model)
print(loaded_model.model_info())
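The placeholder strings in `model_info` have to come from somewhere at export time. Below is a minimal sketch of how they might be gathered, using only the standard library plus an optional `torch` import. The helper names `current_git_hash` and `collect_metadata` are hypothetical, and it assumes the export script runs inside a git checkout with `git` on the PATH (it falls back to "unknown" otherwise).

```python
import subprocess
from datetime import datetime, timezone
from typing import Dict


def current_git_hash(default: str = "unknown") -> str:
    """Return the current commit hash, or `default` if git is unavailable."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git is not installed, or we are not inside a repository.
        return default
    return result.stdout.strip() or default


def collect_metadata(accuracy: float) -> Dict[str, str]:
    """Gather export-time metadata as strings (hypothetical helper)."""
    try:
        import torch
        pytorch_version = str(torch.__version__)
    except ImportError:
        # Keeps the sketch usable on machines without PyTorch.
        pytorch_version = "unknown"
    return {
        "git_hash": current_git_hash(),
        "created": datetime.now(timezone.utc).isoformat(),
        "accuracy": str(accuracy),
        "pytorch_version": pytorch_version,
    }


if __name__ == "__main__":
    print(collect_metadata(10))
```

All values are kept as strings so the dictionary has a single value type. The strings could then be assigned to attributes on the module before calling `torch.jit.script`, the same way `accuracy` is set above, though I have not verified that part end to end.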
I think this practice could really help with tracking models across many systems. I haven't seen the technique used before and would love to hear arguments against it or suggested improvements.
What other information would be useful to embed in a TorchScript model, apart from what I've listed above?