I am using Hogwild-style training with multiprocessing. Inside the model, there is a step that clips the gradient of an intermediate tensor via a backward hook:
kalman_state["mean"].register_hook(
lambda g: g.clamp(-magnitude, magnitude))
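For context, here is a stripped-down sketch of the overall setup; the Linear layer, tensor shapes, and the worker's loss are placeholders standing in for my actual Kalman-filter model:

import torch
import torch.multiprocessing as mp


class KalmanModel(torch.nn.Module):
    def __init__(self, magnitude):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)
        self.magnitude = magnitude

    def forward(self, x):
        kalman_state = {"mean": self.linear(x)}
        # Clip the gradient flowing into this tensor during backward.
        kalman_state["mean"].register_hook(
            lambda g: g.clamp(-self.magnitude, self.magnitude))
        return kalman_state["mean"]


def worker(model):
    # Hogwild: each worker updates the shared parameters lock-free.
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(100):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()


if __name__ == "__main__":
    model = KalmanModel(magnitude=1.0)
    model.share_memory()  # put the parameters in shared memory
    procs = [mp.Process(target=worker, args=(model,)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()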
However, once I added this clipping, the gradients exploded to extremely large values (>1e10).
Adding a mutex solved the problem:
with self.register_mutex:
    kalman_state["mean"].register_hook(
        lambda g: g.clamp(-magnitude, magnitude))
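For reference, register_mutex is created once in the model's constructor, along the lines of this sketch (a plain threading.Lock, which only serializes threads within a single process; a multiprocessing.Lock would be needed to serialize separate worker processes):

import threading
import torch


class KalmanModel(torch.nn.Module):
    def __init__(self, magnitude):
        super().__init__()
        self.magnitude = magnitude
        # Serializes concurrent register_hook calls from worker threads.
        self.register_mutex = threading.Lock()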
So I presume that calling register_hook simultaneously from two threads is causing the issue. If that is the case, is this a bug or a known limitation?