It’s a memory optimization. In pipeline parallelism, an output tensor that has already been sent to the next stage can have its underlying data released, since only its autograd graph is still needed locally. However, `.backward()` checks that the shapes of the `tensors` and `grad_tensors` arguments match, so the pseudo-freed tensor would fail that check. To avoid this shape checking, Megatron-LM uses this API directly, as discussed in the docstring of this function: Megatron-LM/megatron/core/pipeline_parallel/schedules.py at e33c8f78a35765d5aa37475a144da60e8a2349d1 · NVIDIA/Megatron-LM · GitHub
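For context, the pattern in the linked file boils down to two steps: pseudo-free the output’s storage, then invoke the C++ autograd engine, which skips the Python-level shape check. Below is a condensed sketch of that pattern; it is paraphrased from the linked commit, and the exact assertions and arguments may differ in other Megatron-LM versions:

```python
import torch
from torch.autograd import Variable

def deallocate_output_tensor(out: torch.Tensor) -> None:
    # Pseudo-free the storage: keep the tensor object (and its autograd
    # graph) alive, but replace its data with a 1-element placeholder.
    assert out._base is None, "freeing a view's data would be counter-productive"
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

def custom_backward(output: torch.Tensor, grad_output: torch.Tensor) -> None:
    # torch.autograd.backward() would reject the (1,)-shaped placeholder,
    # so call the C++ engine directly; it does not re-check shapes.
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
```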
As a concrete example, you can uncomment these two lines in the original demo code:
```python
# assert z._base is None
# z.data = torch.empty((1,), device=z.device, dtype=z.dtype)
```
Then you can no longer call `.backward()`: PyTorch raises an error about the mismatched shapes of `z` and `dz`.
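If you don’t have the original demo at hand, here is a minimal standalone reproduction (the names `z` and `dz` mirror the demo; everything else is illustrative):

```python
import torch

x = torch.randn(4, requires_grad=True)
z = x * 2                 # non-leaf output whose graph we still need
dz = torch.ones_like(z)   # upstream gradient with z's original shape

# The two uncommented lines: pseudo-free z's storage.
assert z._base is None    # z owns its storage (it is not a view)
z.data = torch.empty((1,), device=z.device, dtype=z.dtype)

# Fails: torch.autograd.backward() compares z's shape (now (1,)) against
# dz's shape and raises a RuntimeError about the mismatch.
z.backward(dz)
```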