Purpose of specifying the devices in model_state

My model has its tensors on CUDA, and its state is saved as such. When I load the model state for inference with checkpoint = torch.load(weights_fpath), the model state keeps the same device configuration. This adds overhead to torch.load (in my case 1-2 seconds). Calling checkpoint = torch.load(weights_fpath, "cpu") takes around 1 millisecond, however. I don't really understand the point of all this, because when I call model.load_state_dict(checkpoint["model_state"]) afterwards, all my tensors end up on the CPU anyway, regardless of the device specified in the model state.

  1. What's the point of setting the device of the tensors in the model state, and what's the use of the map_location argument of torch.load?

  2. Is there a way to directly load cuda tensors when doing model.load_state_dict() that would be faster than loading cpu tensors then calling model.cuda()?

  1. The map_location argument remaps the tensors to the specified device. This is useful if your current machine does not have a GPU and you are trying to load a tensor that was saved as a CUDA tensor.

From the docs:

torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
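A minimal sketch of the remapping described above (the checkpoint path is hypothetical; on a GPU machine the saved tensor could be a CUDA tensor, and map_location="cpu" would still force it onto the CPU at load time):

```python
import os
import tempfile

import torch

# Save a small checkpoint dict; on a CUDA machine this tensor could
# live on the GPU and would be saved with that device recorded.
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save({"w": torch.randn(3)}, path)

# map_location="cpu" deserializes every storage directly onto the CPU,
# so the checkpoint loads even if the saving device is unavailable here.
ckpt = torch.load(path, map_location="cpu")
print(ckpt["w"].device)  # cpu
```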

  1. You could push the model to the GPU before loading the state_dict. However, the tensors have to be moved to the device at some point either way, so I guess you won't save any time.
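A sketch of that approach (the checkpoint path and the tiny nn.Linear model are made up for illustration): load_state_dict copies the checkpoint values into the model's existing parameters, so they keep whatever device the model is already on, even when the checkpoint was loaded onto the CPU.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical checkpoint: save a small model's state_dict to disk first.
path = os.path.join(tempfile.mkdtemp(), "model_state.pt")
torch.save(nn.Linear(4, 2).state_dict(), path)

# Push the model to the GPU (if available) *before* loading the state.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)

# Cheap CPU-side load; load_state_dict then copies the values into the
# model's parameters, which stay on `device`.
state = torch.load(path, map_location="cpu")
model.load_state_dict(state)
print(next(model.parameters()).device)
```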