Saving a model's state_dict into a Redis cache

I’m building a distributed parameter-server type architecture and want to communicate model updates through table solutions on Azure.

I’m having a hard time finding any useful information about saving a model’s state_dict into a Redis cache. I’ve given up on Azure Cosmos tables because of the size limit (64 KB) per entity and turned to Redis instead, since a state_dict’s parameters/weights are much larger than that, even for a small model.

Does anyone have any recommendations for me on how to pursue this?

Is this question about 1) whether Redis is an appropriate store for model states, 2) how to configure Azure to run Redis, or 3) how to build a parameter server using PyTorch?

Azure is simply the platform I’m developing on. I am looking for the answer to 1): is Redis an appropriate store for model parameters and weights?

I’ve recently learned about RedisAI, but it does not have an equivalent Azure service and would have to be deployed on a dedicated VM.

Hmm, doesn’t this mainly depend on the data size and IO pattern? Or does being a DNN model make any difference?
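
For a sense of the data size involved, here is a minimal sketch that serializes a state_dict to bytes with torch.save and round-trips it through a key/value store. The tiny nn.Linear model, the helper names, and the key are made up for illustration, and a plain dict stands in for the cache so the sketch runs without a server; with the redis-py client, client.set(key, payload) and client.get(key) would presumably take its place.

```python
import io

import torch
import torch.nn as nn


def state_dict_to_bytes(state_dict):
    """Serialize a state_dict into a bytes payload using torch.save."""
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    return buffer.getvalue()


def bytes_to_state_dict(payload):
    """Rebuild a state_dict from a payload written by state_dict_to_bytes."""
    return torch.load(io.BytesIO(payload), map_location="cpu")


# Hypothetical small model, used only to get a realistic payload size.
model = nn.Linear(256, 128)
payload = state_dict_to_bytes(model.state_dict())

# Assumption: a dict stands in for the cache (not a real Redis connection).
cache = {}
cache["global_model:state_dict"] = payload

# Read the payload back and load it into the model.
restored = bytes_to_state_dict(cache["global_model:state_dict"])
model.load_state_dict(restored)

print(f"payload size: {len(payload) / 1024:.1f} KiB")
```

Even this single 256x128 layer carries more than 128 KiB of float32 weights, so it is already over the 64 KB per-entity limit mentioned earlier in the thread.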

I am curious why you are communicating model updates via an external DB. Normally model updates are communicated between ranks via collective communication, or with something like EASGD (https://arxiv.org/abs/1412.6651). Is your goal debugging, logging, or improved reliability? It seems like updating the model via an external DB would be a performance hit.

I’m testing out parallelizing across multiple worker nodes in a parameter-server type architecture. I’m using RedisAI to handle the sharing of model weights and gradients between the primary and worker nodes. Each worker retains its own parameter set and delivers gradients to the primary node, which updates the global model. I have four workers, so I combine the four workers’ updates into one.

Sequentially performing worker updates, the global update, and then worker reads of the updated global model is a performance hit. I’m also testing out an update-and-continue scheme, where a worker pushes its gradients to the global model and then continues along its own path instead of adjusting to the global model.

import torch

beta = 0.25  # equal weight for each of the four workers
worker_ids = ['worker_001', 'worker_002', 'worker_003', 'worker_004']

# Assign gradients to the live parameters. Tensors returned by
# model.state_dict() are detached copies of the parameters, so a .grad set
# on them never reaches the optimizer, and load_state_dict() ignores it.
for name, param in model.named_parameters():
    # Weighted sum of the per-worker gradients stored in RedisAI.
    tens = sum(redisai_conn.tensorget(f'{wid}:{name}_grad') * beta for wid in worker_ids)
    worker_ten = torch.from_numpy(tens).to(param.device)
    if param.grad is None:
        param.grad = worker_ten
    else:
        param.grad.copy_(worker_ten)

My goal is improved reliability in the global model’s predictions.

I see. For this use case, an alternative is to use torch.distributed.rpc to connect the parameter server with the trainers, and then let the parameter server periodically flush checkpoints to the external storage, so that you don’t have to pay the checkpointing overhead in every iteration.

Some related resources:

  1. Building HogWild! PS using torch.distributed.rpc: https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html
  2. Batch updating PS (requires v1.6+): https://github.com/pytorch/tutorials/blob/release/1.6/intermediate_source/rpc_async_execution.rst
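
The periodic-flush idea from this reply can be sketched roughly as follows. This is only an illustration under stated assumptions: the class name, the flush_every parameter, the plain SGD update, and the dict used as external storage are all hypothetical, and the RPC wiring that would let trainers call apply_gradients remotely is omitted (the tutorials above cover it).

```python
import io
import threading

import torch
import torch.nn as nn


class CheckpointingParameterServer:
    """Hypothetical parameter server that checkpoints every few updates."""

    def __init__(self, model, store, lr=0.1, flush_every=10):
        self.model = model
        self.store = store  # assumption: any dict-like external storage
        self.flush_every = flush_every
        self.optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        self.num_updates = 0
        # Remote calls may arrive concurrently, so updates are serialized.
        self.lock = threading.Lock()

    def apply_gradients(self, grads):
        """Apply one trainer's gradients, given as a dict of name -> tensor."""
        with self.lock:
            for name, param in self.model.named_parameters():
                # Clone so the caller's tensors are never modified.
                param.grad = grads[name].detach().clone().to(param.device)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.num_updates += 1
            # Pay the checkpointing cost only every flush_every updates.
            if self.num_updates % self.flush_every == 0:
                self._flush()

    def _flush(self):
        """Write the current state_dict to the external storage as bytes."""
        buffer = io.BytesIO()
        torch.save(self.model.state_dict(), buffer)
        self.store["checkpoint"] = buffer.getvalue()
        self.store["checkpoint_step"] = self.num_updates


# Usage: seven updates with flush_every=3 flush at updates 3 and 6.
model = nn.Linear(4, 2)
store = {}
ps = CheckpointingParameterServer(model, store, flush_every=3)
for _ in range(7):
    grads = {n: torch.ones_like(p) for n, p in model.named_parameters()}
    ps.apply_gradients(grads)
print(store["checkpoint_step"])
```

With this layout the external store only sees one write per flush_every updates, while the trainers talk to the parameter server directly.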