I would like to share a PyTorch model between two processes. I have a trainer and an application (inference) module running as separate processes. I have implemented model.save/load, but when the trainer is writing the file and the inference module tries to read it at the same time, it gives an error (obviously).
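For what it's worth, the read-during-write error can be avoided without shared memory at all: write the checkpoint to a temporary file in the same directory and atomically rename it over the target, so the reader only ever sees a complete file. This is a minimal sketch; `atomic_save` and its `save_fn` parameter are names I made up, and with PyTorch you would pass `torch.save` as `save_fn`.

```python
import os
import tempfile

def atomic_save(obj, path, save_fn):
    # Write to a temp file in the same directory as the target, then
    # atomically rename it over the target. A concurrent reader either
    # sees the old complete file or the new complete file, never a
    # half-written one.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    os.close(fd)
    try:
        save_fn(obj, tmp)
        os.replace(tmp, path)  # atomic rename on POSIX and Windows
    except BaseException:
        os.remove(tmp)  # don't leave partial temp files behind
        raise

# With PyTorch: atomic_save(model.state_dict(), "weights.pt", torch.save)
# The inference process then just calls torch.load("weights.pt") periodically.
```

The inference side can additionally check the file's mtime to decide whether a reload is needed.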
I have been looking at this forum and everywhere else, and it drives me crazy: all I find are complex solutions for training with multiple processes etc., which is not what I want. I am really wondering how hard it can be to simply share a state dict between two processes.
One attempt was using SharedMemoryManager, here is an example of what I wrote:
from copy import deepcopy
from multiprocessing import Process
from multiprocessing.managers import SharedMemoryManager

import numpy as np

with SharedMemoryManager() as shared_memory_manager:
    # Initialize one shared-memory block per numpy array
    for key in shared_memory_settings:
        shape = shared_memory_settings[key]["shape"]
        dtype = shared_memory_settings[key]["dtype"]
        nbytes = np.dtype(dtype).itemsize * int(np.prod(shape))
        shared_memory_settings[key]["shm"] = shared_memory_manager.SharedMemory(size=nbytes)
    # sys.getsizeof() only measures the dict object itself, not the tensors,
    # so size the block by the actual number of bytes in the weights
    state_dict_nbytes = sum(t.numel() * t.element_size() for t in model.model.state_dict().values())
    shared_memory_settings["state_dict"] = shared_memory_manager.SharedMemory(size=state_dict_nbytes)

    # Initialize and start the trainer and inference processes
    p1 = Process(target=Trainer, args=(shared_memory_settings, deepcopy(model), model_weights_path))
    p2 = Process(target=Application, args=(dataset, model, model_weights_path, shared_memory_settings))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
SharedMemory works fine for numpy arrays, but it gives me a difficult time when it comes to sharing the state dict.
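One way I could imagine making the state dict fit this model is to not put the whole dict into a single block at all, but give each tensor its own SharedMemory block and rebuild numpy views on the reader side. This is only a sketch under my own assumptions: `publish` and `attach` are hypothetical helpers, and plain numpy arrays stand in for the tensors you would get from `tensor.detach().cpu().numpy()`.

```python
import numpy as np
from multiprocessing import shared_memory

def publish(state_dict):
    # One named SharedMemory block per tensor; the writer copies the
    # weights in, and hands out (block, shape, dtype) so readers can
    # reconstruct a view without pickling anything.
    blocks = {}
    for name, arr in state_dict.items():
        shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
        view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
        view[...] = arr  # copy the weights into shared memory
        blocks[name] = (shm, arr.shape, arr.dtype)
    return blocks

def attach(shm_name, shape, dtype):
    # Reader side: attach to an existing block by name and wrap it in a
    # zero-copy numpy view.
    shm = shared_memory.SharedMemory(name=shm_name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)
```

In a real setup the trainer would overwrite the views in place after each update, and the inference process would copy them back into its model with `load_state_dict`; note that without some synchronization the reader can observe a half-updated set of weights.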
So my question is: what is the best practice for a very simple problem, namely using model weights for inference while a separate, independent process is training and updating those weights on a regular basis?
I have looked into queues, but at first glance they seem to overcomplicate things. Perhaps that is the way I should go?
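In case a queue turns out to be the answer, here is roughly how I imagine it staying simple: the trainer puts a fresh snapshot of the weights on the queue after each update, and the inference loop drains the queue non-blockingly, keeping only the newest snapshot. The `latest` helper below is a name of my own; with PyTorch you would put `{k: v.cpu() for k, v in model.state_dict().items()}` on a `torch.multiprocessing.Queue`.

```python
import queue

def latest(q, current):
    # Drain the queue and return the newest item, or `current` if the
    # queue is empty. The inference loop never blocks on the trainer.
    while True:
        try:
            current = q.get_nowait()  # a newer snapshot is available
        except queue.Empty:
            return current

# Inference loop sketch:
#   state = latest(q, state)
#   if state is not None:
#       model.load_state_dict(state)
```

This works with both `queue.Queue` and `multiprocessing.Queue`, since both raise `queue.Empty` from `get_nowait()`.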