Persisting benchmark cache

When torch.backends.cudnn.benchmark is enabled, PyTorch times the available cuDNN algorithms for each convolution and selects the fastest one.
Internally, PyTorch stores the result of each search in a cache keyed by the convolution parameters, so a given configuration is only benchmarked once.
The problem is that our model's input shapes change a lot. Every new shape is a cache miss that triggers another cuDNN search, which slows down the whole run. So we wonder whether we can persist the cache to disk to improve the hit rate while still using the fastest implementation.
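To make the cost of varying shapes concrete, here is a toy model of benchmark-mode caching. `Shape`, `AlgoCache` and `benchmark` are hypothetical stand-ins, not PyTorch code; the real cache is keyed on the full convolution parameters (dtype, padding, stride, etc.), not just the input shape, and the real search times actual cuDNN kernels.

```cpp
#include <map>
#include <tuple>

// Hypothetical stand-in for a convolution input shape (N, C, H, W).
using Shape = std::tuple<int, int, int, int>;

// Toy model of benchmark-mode caching: the first time a shape is seen the
// (expensive) search runs; later lookups for the same shape are cache hits.
class AlgoCache {
 public:
  int pick_algorithm(const Shape& s) {
    auto it = cache_.find(s);
    if (it != cache_.end()) return it->second;  // hit: reuse the stored result
    int algo = benchmark(s);                     // miss: run the search
    cache_.emplace(s, algo);
    return algo;
  }

  int searches() const { return searches_; }

 private:
  // Hypothetical placeholder for timing every cuDNN algorithm: it only
  // counts how often it is called and returns a dummy algorithm index.
  int benchmark(const Shape& s) {
    ++searches_;
    return (std::get<2>(s) + std::get<3>(s)) % 3;
  }

  std::map<Shape, int> cache_;
  int searches_ = 0;
};
```

Calling `pick_algorithm` 100 times with one fixed shape runs the search once, while sweeping H over 100 distinct values runs it 100 times. A cache that survives across runs would turn those repeated misses into hits.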

A simple approach looks like this:

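#include <fstream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// ConvolutionParams, ParamsHash and ParamsEqual are PyTorch's internal cuDNN
// convolution types; T is the cached search result.
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash<ConvolutionParams>, ParamsEqual<ConvolutionParams>> map;
  std::string CACHE_STORED_PATH;

  explicit BenchmarkCache(std::string cache_stored_path)
      : CACHE_STORED_PATH(std::move(cache_stored_path)) {}

  // Returns true and copies the cached result into *results on a hit.
  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  // Records (or overwrites) the search result for params.
  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }

  // Reads the mapping from CACHE_STORED_PATH; returns -1 if it cannot be opened.
  int load() {
    std::lock_guard<std::mutex> guard(mutex);
    std::ifstream i_cache(CACHE_STORED_PATH, std::ios::in);
    if (!i_cache) {
      return -1;
    }
    // TODO: deserialize the entries into `map`
    return 0;
  }

  // Writes the mapping to CACHE_STORED_PATH; returns -1 if it cannot be opened.
  int save() {
    std::lock_guard<std::mutex> guard(mutex);
    std::ofstream o_cache(CACHE_STORED_PATH, std::ios::out);
    if (!o_cache) {
      return -1;
    }
    // TODO: serialize the entries in `map`
    return 0;
  }
};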
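Since the struct above depends on PyTorch internals, here is a self-contained sketch of the same idea with the load/save bodies filled in. `ConvKey`, `ConvKeyHash`, `AlgoPerf` and `PersistentCache` are hypothetical simplifications (the real key has many more fields, and the real value would be a cuDNN perf struct), and the whitespace-separated text format is just one assumed choice.

```cpp
#include <array>
#include <cstddef>
#include <fstream>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, simplified key: the real ConvolutionParams also records
// dtype, padding, stride, dilation, groups, memory format, device, etc.
struct ConvKey {
  std::array<int, 4> input;   // N, C, H, W
  std::array<int, 4> weight;  // K, C, R, S
  bool operator==(const ConvKey& o) const {
    return input == o.input && weight == o.weight;
  }
};

struct ConvKeyHash {
  std::size_t operator()(const ConvKey& k) const {
    std::size_t h = 0;
    for (int v : k.input) h = h * 31 + std::hash<int>{}(v);
    for (int v : k.weight) h = h * 31 + std::hash<int>{}(v);
    return h;
  }
};

// Hypothetical value type standing in for a cuDNN algorithm perf result.
struct AlgoPerf {
  int algo;
  float time_ms;
};

class PersistentCache {
 public:
  explicit PersistentCache(std::string path) : path_(std::move(path)) {}

  bool find(const ConvKey& k, AlgoPerf* out) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = map_.find(k);
    if (it == map_.end()) return false;
    *out = it->second;
    return true;
  }

  void insert(const ConvKey& k, const AlgoPerf& v) {
    std::lock_guard<std::mutex> guard(mutex_);
    map_[k] = v;
  }

  // Writes one line per entry: 8 key ints, then algo and time.
  bool save() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::ofstream out(path_);
    if (!out) return false;
    for (const auto& [k, v] : map_) {
      for (int x : k.input) out << x << ' ';
      for (int x : k.weight) out << x << ' ';
      out << v.algo << ' ' << v.time_ms << '\n';
    }
    return static_cast<bool>(out);
  }

  // Merges entries from disk into memory; false if the file can't be opened.
  bool load() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::ifstream in(path_);
    if (!in) return false;
    ConvKey k;
    AlgoPerf v;
    while (read_entry(in, k, v)) map_[k] = v;
    return true;
  }

  std::size_t size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return map_.size();
  }

 private:
  static bool read_entry(std::istream& in, ConvKey& k, AlgoPerf& v) {
    for (int& x : k.input) {
      if (!(in >> x)) return false;
    }
    for (int& x : k.weight) {
      if (!(in >> x)) return false;
    }
    return static_cast<bool>(in >> v.algo >> v.time_ms);
  }

  std::string path_;
  std::mutex mutex_;
  std::unordered_map<ConvKey, AlgoPerf, ConvKeyHash> map_;
};
```

Because `load()` merges into whatever is already in memory, it can safely run before the first convolution, and `save()` can run at shutdown or periodically. Note that the default stream precision for `float` is 6 significant digits, so stored timings may be rounded; only the algorithm choice really needs to round-trip exactly.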


Of course, there are a few issues that still need to be worked out:

  • When to call load() and save(): are there places in PyTorch where these calls can be inserted without hurting performance?
  • One file per model, or a global cache that other models can also use?
  • Locking for load and save: many threads may try to read or write the data at the same time.
  • How to load and save the data with less time and space overhead. For example, PyTorch already computes a hash of ConvolutionParams; could we store that hash instead, so the cache maps a size_t value (derived from ConvolutionParams) to cudnnConvolutionAlgoPerf?
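One possible answer to the last two points is sketched below, under several assumptions. The key is a 64-bit hash of the convolution parameters (presumably the value PyTorch's `ParamsHash` already computes), each entry is a fixed-size binary record, lookups take a shared lock so concurrent readers don't block each other, and `save()` writes a temporary file and renames it over the real one so no reader sees a half-written file (atomic replacement is POSIX `rename` behavior). `HashedCache`, `AlgoPerf` and `env_tag` are hypothetical names. Storing only the hash has a real cost: two different parameter sets with the same hash would silently share an entry. The hash is also only as stable as the struct layout across builds, and the fastest algorithm depends on the GPU and cuDNN version, so a shared global cache would need something like `env_tag` to reject files from a different environment.

```cpp
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical fixed-size record standing in for a cuDNN perf result.
struct AlgoPerf {
  std::int32_t algo;
  float time_ms;
};

// Sketch: cache keyed only by a 64-bit hash of the convolution parameters.
class HashedCache {
 public:
  HashedCache(std::string path, std::uint64_t env_tag)
      : path_(std::move(path)), env_tag_(env_tag) {}

  // Concurrent lookups only need a shared (reader) lock.
  bool find(std::uint64_t key, AlgoPerf* out) const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it == map_.end()) return false;
    *out = it->second;
    return true;
  }

  void insert(std::uint64_t key, const AlgoPerf& v) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    map_[key] = v;
  }

  // Copies the map, writes "<path>.tmp", then renames it over the real file.
  bool save() const {
    std::unordered_map<std::uint64_t, AlgoPerf> snapshot;
    {
      std::shared_lock<std::shared_mutex> lock(mutex_);
      snapshot = map_;  // lookups are not blocked during disk I/O
    }
    std::lock_guard<std::mutex> io_guard(io_mutex_);  // one tmp-file writer per process
    const std::string tmp = path_ + ".tmp";
    {
      std::ofstream out(tmp, std::ios::binary | std::ios::trunc);
      if (!out) return false;
      const std::uint64_t count = snapshot.size();
      write_pod(out, env_tag_);
      write_pod(out, count);
      for (const auto& [key, v] : snapshot) {
        write_pod(out, key);
        write_pod(out, v);
      }
      if (!out) return false;
    }
    return std::rename(tmp.c_str(), path_.c_str()) == 0;
  }

  // Reads without holding the lock, then merges; rejects missing, truncated
  // or foreign-environment files without touching the in-memory map.
  bool load() {
    std::ifstream in(path_, std::ios::binary);
    std::uint64_t tag = 0, count = 0;
    if (!in || !read_pod(in, tag) || tag != env_tag_ || !read_pod(in, count)) return false;
    std::unordered_map<std::uint64_t, AlgoPerf> loaded;
    for (std::uint64_t i = 0; i < count; ++i) {
      std::uint64_t key;
      AlgoPerf v;
      if (!read_pod(in, key) || !read_pod(in, v)) return false;
      loaded[key] = v;
    }
    std::unique_lock<std::shared_mutex> lock(mutex_);
    for (const auto& [key, v] : loaded) map_[key] = v;
    return true;
  }

 private:
  template <typename U>
  static void write_pod(std::ofstream& out, const U& x) {
    out.write(reinterpret_cast<const char*>(&x), sizeof(U));
  }
  template <typename U>
  static bool read_pod(std::ifstream& in, U& x) {
    return static_cast<bool>(in.read(reinterpret_cast<char*>(&x), sizeof(U)));
  }

  std::string path_;
  std::uint64_t env_tag_;
  mutable std::shared_mutex mutex_;
  mutable std::mutex io_mutex_;
  std::unordered_map<std::uint64_t, AlgoPerf> map_;
};
```

Each entry costs 16 bytes on disk (an 8-byte key plus an 8-byte record), independent of how large the parameter struct is. If several processes share one cache file, each would also need its own temporary file name, since `io_mutex_` only serializes writers within one process.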