Oh right, I forgot about `__torch_dispatch__`. (I lack a bit of experience with what I can hook/dispatch.) Yeah, via this mechanism, I could do exactly what I need. That seems great!
Note, in RETURNN, for the TensorFlow backend, I have implemented this API for gradient checkpointing, where you can do something like this:
```python
with gradient_checkpoint_scope():
    x = a + b
    y = x * c
```
All tensors in the `gradient_checkpoint_scope` will get specially marked. This is using TF graph mode. I just need to record the first node/op id and the last. And then I can iterate through all ops from the scope like this:
```python
for op_id in range(first_op_id, last_op_id):
    op = graph._nodes_by_id[op_id]
    ...
```
And then I overwrite the grad func of all consumers:
```python
for op_out in op.outputs:
    assert isinstance(op_out, tf.Tensor)
    for op_ in op_out.consumers():
        _set_wrapped_grad_func(op_)
```
This is before I calculate the gradient (`tf.gradients`, similar to Torch `tensor.backward()`).
In my wrapped gradient function, instead of passing any of the original tensors from that `gradient_checkpoint_scope` to the original gradient function, I copy the whole computation graph under this scope and pass the copied tensors. I don't really need to worry about adding unnecessary computations to the computation graph, as TensorFlow only lazily calculates whatever is needed in the end.
So, I wonder if I can implement the same logic in PyTorch, i.e. the same `gradient_checkpoint_scope` API. Very similarly to your example, I could use `__torch_dispatch__` to record the computation graph and also record all tensors which are created there. If I used `saved_tensors_hooks` on the whole model, I could check whether any of the saved tensors come from such a checkpoint scope, and if so, recompute them.
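Something like the following could be the recording part. This is just a minimal sketch under my assumptions (the class name `RecordingMode` and its bookkeeping are made up here, nothing existing), using a `TorchDispatchMode` to intercept every op while the scope is active:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


class RecordingMode(TorchDispatchMode):
    """Record every dispatched op and the tensors it creates while active."""

    def __init__(self):
        super().__init__()
        self.recorded_ops = []  # list of (func, args, kwargs, outputs)
        self.created_tensor_ids = set()  # tensors created inside the scope

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)  # run the real op
        self.recorded_ops.append((func, args, kwargs, out))

        def _mark(x):
            if isinstance(x, torch.Tensor):
                # A real implementation would probably keep weakrefs instead of ids.
                self.created_tensor_ids.add(id(x))
            return x

        tree_map(_mark, out)
        return out


a, b, c = torch.randn(3), torch.randn(3), torch.randn(3)
with RecordingMode() as mode:
    x = a + b
    y = x * c
# mode.recorded_ops now holds the aten add/mul calls;
# x and y are registered in mode.created_tensor_ids.
```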
If I don't want to have `saved_tensors_hooks` on the whole model, I can solve this too: I can call `saved_tensors_hooks.__enter__` when I enter the checkpoint scope, but not immediately call `saved_tensors_hooks.__exit__` afterwards. Instead, I can hook into `Tensor.__del__` (as described above) for all tensors created in that scope, and once they all have been deleted, I can call `saved_tensors_hooks.__exit__` (and also exit the `__torch_dispatch__` mode). So this stays very local.
It gets a bit trickier to handle the RNG state correctly, and maybe also AMP. I guess I need to make sure to store the state, then later fork and reset it correctly, and also to replay the recorded computations in exactly the same order, to make sure it's all deterministic. But this is similar to the current Torch `checkpoint` logic.
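For the RNG part, roughly following what `torch.utils.checkpoint` does, a small sketch could look like this (the helper names are made up; AMP state would need similar treatment via the autocast context):

```python
import torch


def capture_rng_state(device=None):
    """Snapshot the CPU (and optionally one CUDA device's) RNG state."""
    state = {"cpu": torch.get_rng_state()}
    if device is not None and device.type == "cuda":
        state["cuda"] = torch.cuda.get_rng_state(device)
    return state


def replay_with_rng_state(state, recompute, device=None):
    """Re-run the recorded computation with the saved RNG state restored,
    inside a forked RNG so the surrounding RNG stream is not disturbed."""
    cuda_devices = [device] if "cuda" in state else []
    with torch.random.fork_rng(devices=cuda_devices):
        torch.set_rng_state(state["cpu"])
        if "cuda" in state:
            torch.cuda.set_rng_state(state["cuda"], device)
        return recompute()
```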
So, this sounds like a plan…