Is it possible to replace gradients with stale ones to mitigate communication overhead?

Here are my thoughts:

  1. Inter-node communication is usually the bottleneck for multi-node DDP;
  2. Over a low-bandwidth inter-node connection, the gradient all-reduce can take even longer than the forward pass (in extreme cases);
  3. PyTorch already overlaps some of the gradient all-reduce ops with the backward computation (the stock all-reduce comm hook is sketched right after this list for reference).
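
For reference, my understanding is that the stock all-reduce comm hook looks roughly like the following (paraphrased from torch.distributed.algorithms.ddp_comm_hooks.default_hooks; details may differ between versions). Every bucket kicks off an async all-reduce as soon as its gradients are ready, while the rest of backward keeps running, and DDP waits on the returned future:

import torch.distributed as dist

def allreduce_hook(process_group, bucket):
    # Average this bucket's flattened gradients across ranks; DDP waits on the
    # returned future before finalizing .grad for this bucket's parameters.
    group = process_group if process_group is not None else dist.group.WORLD
    tensor = bucket.buffer().div_(group.size())
    fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()
    # The c10d future resolves to a list of tensors; DDP expects a single tensor.
    return fut.then(lambda f: f.value()[0])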

And I wonder: if we use grad_t-1 (computed during the previous step's backward) instead of grad_t to update the parameters, then the all-reduce for grad_t-1 can keep running in parallel with the forward pass of step t. That way the communication overlaps with both the backward pass and the (next step's) forward pass, which should hide most of the communication cost.
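
To make the intended schedule concrete, here is a hand-rolled sketch using plain torch.distributed (no DDP, no comm hook); compute_loss and the handle bookkeeping are placeholders of mine, just to show which communication is supposed to hide behind which compute:

import torch.distributed as dist

def train_step(model, optimizer, batch, pending):
    # `pending` carries (all-reduce handles, gradient clones) from step t-1.
    loss = compute_loss(model, batch)   # forward of step t; grad_t-1's all-reduce may still be in flight
    loss.backward()                     # local grad_t
    grads_t = [p.grad.detach().clone() for p in model.parameters()]
    handles_t = [dist.all_reduce(g, async_op=True) for g in grads_t]   # start reducing grad_t

    if pending is not None:
        handles_prev, grads_prev = pending
        for h in handles_prev:
            h.wait()                    # grad_t-1 should already be reduced by now
        for p, g in zip(model.parameters(), grads_prev):
            p.grad.copy_(g / dist.get_world_size())   # averaged grad_t-1
        optimizer.step()                # update with one-step-stale gradients
    optimizer.zero_grad()
    return handles_t, grads_t           # becomes `pending` for step t+1

With pending = None before the first step (and the last pending reduction flushed after the loop), the all-reduce of grad_t only has to finish before the optimizer step of step t+1, so it can hide behind the entire forward and backward of the next step. The comm hook attempt below is meant to get the same effect without leaving DDP.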

What I tried so far: adding a custom hook via register_comm_hook(), in which I implemented a “queue” of gradient buckets, something like this:

def custom_hook(state, gradbucket):
    # state is my own object with one slot per bucket index.
    # pull() returns a torch.futures.Future whose .value() is this bucket's
    # gradients from the previous step, already all-reduced.
    last_fut = state.pull(gradbucket.index())
    # push() launches an async all-reduce on the current bucket's gradients
    # and stores the resulting future for the next step to pull.
    state.push(gradbucket)
    return last_fut
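
The state object is my own code; a trimmed-down sketch of it (class and method names are mine, not a PyTorch API) is roughly the following. The one difference from the hook above is that pull() takes the whole bucket instead of just its index, so it can return zeros on the very first step:

import torch
import torch.distributed as dist

class DelayedBucketState:
    # One slot per bucket index, holding the future of the all-reduce that was
    # launched for that bucket on the previous step.
    def __init__(self):
        self.pending = {}

    def push(self, bucket):
        # Clone the bucket first, so DDP can overwrite the original buffer with
        # the stale gradients returned from the hook while the collective runs.
        tensor = bucket.buffer().clone().div_(dist.get_world_size())
        fut = dist.all_reduce(tensor, async_op=True).get_future()
        self.pending[bucket.index()] = fut.then(lambda f: f.value()[0])

    def pull(self, bucket):
        # Previous step's reduced gradients for this bucket; on the very first
        # step nothing is pending yet, so fall back to zeros (a no-op update).
        if bucket.index() in self.pending:
            return self.pending[bucket.index()]
        fut = torch.futures.Future()
        fut.set_result(torch.zeros_like(bucket.buffer()))
        return fut

In this sketch the division by the world size just follows the averaging convention of the built-in all-reduce hook, and the state gets registered the usual way, i.e. ddp_model.register_comm_hook(state, custom_hook).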

But it didn’t work: I saw no time savings at all in my experiments. And since everything around the gradient reducer is wrapped up as a black box with little documentation, it’s hard to debug.

Any help or suggestions would be appreciated.