I see. One more question: will you move the model around, e.g. to a different machine with a different number of GPUs, or are you loading the whole model onto the same devices?
If you don’t, and you really want to save them separately to different files, maybe for better inspection or archival purposes, then:
```python
def save(your_model):
    torch.save(your_model.fc1, "fc1.pt")
    torch.save(your_model.fc2, "fc2.pt")
    torch.save(your_model.fc3, "fc3.pt")
```
If you do, then you will have to decide which device each part of your model should be located on. E.g., suppose on your training machine you have 3 GPUs, and on your inference machine you have 1 GPU:
```python
def save(your_model):
    torch.save(your_model.fc1, "fc1.pt")
    torch.save(your_model.fc2, "fc2.pt")
    torch.save(your_model.fc3, "fc3.pt")
```
```python
def map(your_model):
    your_model.fc1 = torch.load("fc1.pt", map_location=torch.device('cuda:0'))
    your_model.fc2 = torch.load("fc2.pt", map_location=torch.device('cuda:0'))
    your_model.fc3 = torch.load("fc3.pt", map_location=torch.device('cuda:0'))
```
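Put together, the two helpers above could be used like this. This is only a sketch: `Net` and its layer sizes are hypothetical, the target device falls back to CPU so it runs without a GPU, and `weights_only=False` is passed because recent PyTorch versions (2.6+) need it to unpickle a full module:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # hypothetical three-layer model matching the fc1/fc2/fc3 snippets above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.fc3 = nn.Linear(8, 4)

model = Net()

# on the training machine: save each part to its own file
torch.save(model.fc1, "fc1.pt")
torch.save(model.fc2, "fc2.pt")
torch.save(model.fc3, "fc3.pt")

# on the inference machine: remap every part onto its target device
# (cuda:0 in the example above; cpu here so the sketch runs anywhere)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.fc1 = torch.load("fc1.pt", map_location=device, weights_only=False)
model.fc2 = torch.load("fc2.pt", map_location=device, weights_only=False)
model.fc3 = torch.load("fc3.pt", map_location=device, weights_only=False)
```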
By the way, maybe you have a wrong idea here: there is no such thing as a “connected device” concept in PyTorch. You can perform a complex forward() operation or a simple add() operation on some input x located on device cuda:[number] or on cpu simply because the operands (tensors) are located on the same device; if torch would need to fetch an operand from somewhere else, it will complain and throw an error.
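To make the rule concrete, here is a minimal sketch; the mixed-device branch only runs when CUDA is actually available:

```python
import torch

a = torch.ones(3)            # lives on cpu
b = torch.ones(3)            # also cpu, so the add just works
print((a + b).sum().item())  # -> 6.0

# mixing devices is what triggers the error (only runs with a GPU):
if torch.cuda.is_available():
    c = torch.ones(3, device='cuda:0')
    try:
        a + c                # cpu operand + cuda operand
    except RuntimeError as err:
        print("torch complains:", err)
```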
About saving the model
There are many ways to save your model. Typically you will want to save the OrderedDict returned by model.state_dict(): the keys are your parameter names such as “linear.weight” or “linear.bias”, and the values are the corresponding tensors (for an nn.Parameter, its .data attribute is just a Tensor). You may load a state dict into your model like:
```python
def prep_load_state_dict(model: nn.Module,
                         state_dict: Any):
    """
    Automatically load a **loaded state dictionary**.

    Note:
        This function handles tensor device remapping.
    """
    for name, param in model.named_parameters():
        # .to() is not in-place: assign the remapped tensor back
        state_dict[name] = state_dict[name].to(param.device)
    model.load_state_dict(state_dict)
```
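For instance, a full state-dict round trip looks like this (a hypothetical single-layer model; `src`/`dst` are illustrative names, and an in-memory buffer stands in for a file):

```python
import io
import torch
import torch.nn as nn

src = nn.Linear(4, 2)   # model whose weights we save
dst = nn.Linear(4, 2)   # fresh model with different random weights

buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)  # keys: 'weight', 'bias'
buffer.seek(0)

dst.load_state_dict(torch.load(buffer))  # copies values into dst's parameters

# dst now computes exactly what src does
x = torch.randn(3, 4)
assert torch.equal(src(x), dst(x))
```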
About torch.save and torch.load
If you know the pickle concept in Python, then you will get what torch.save does: like pickle, it serializes an object into a binary string.
```python
import io
import torch as t

buffer = io.BytesIO()
t.save(t.zeros([5]), buffer)
print(buffer.getvalue())
```
will yield:
```
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\n\x00\x00\x00type_sizesq\x01}q\x02(X\x03\x00\x00\x00intq\x03K\x04X\x04\x00\x00\x00longq\x04K\x04X\x05\x00\x00\x00shortq\x......
```
You can serialize whatever you like this way; a CUDA tensor will essentially be saved as “raw data” + a device descriptor such as “cuda:0”.
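torch.load reverses the process; for a CUDA tensor, the map_location argument decides where the “raw data” ends up. A minimal CPU round trip (so it runs anywhere) looks like:

```python
import io
import torch

x = torch.arange(5, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(x, buffer)    # serialize: raw data + metadata
buffer.seek(0)

y = torch.load(buffer)   # deserialize back into an identical tensor
print(y)                 # tensor([0., 1., 2., 3., 4.])
```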