From the selective activation checkpointing blog, I understood that we can save the outputs of compute-intensive operations while recomputing the memory-intensive ones. How do I add `_grouped_mm` to the `ops_to_save` list below? Does directly adding `torch.ops.aten.grouped_mm.default` work, or does using `torch._grouped_mm` do the trick? I've also sketched a guarded variant right after the list.
import logging
from functools import partial

import torch

logger = logging.getLogger(__name__)

ops_to_save = [
    torch.ops.aten.mm.default,
    torch.ops.aten.bmm.default,
    torch.ops.aten.addmm.default,
    # my attempt at saving grouped mm -- is this the right overload name?
    torch.ops.aten.grouped_mm.default,
    # Save collectives that produce dynamic-shaped outputs
    torch.ops._c10d_functional.all_to_all_single.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
]
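# (Guarded variant I'm also trying, as mentioned above: append the grouped-mm
# aten op only if this PyTorch build registers it. The underscored overload
# name torch.ops.aten._grouped_mm.default is my assumption, based on the name
# of the Python-level torch._grouped_mm wrapper -- please correct me if the
# registered op is called something else.)
if hasattr(torch.ops.aten, "_grouped_mm"):
    ops_to_save.append(torch.ops.aten._grouped_mm.default)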
def _apply_ac(model: torch.nn.Module):
    """Apply selective activation checkpointing to each transformer block."""
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        checkpoint_wrapper,
    )
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    # Policy function: save the outputs of ops in ops_to_save and let
    # everything else be recomputed in the backward pass.
    def policy_fn(ctx, op, *args, **kwargs):
        if op in ops_to_save:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = checkpoint_wrapper(
            transformer_block,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
        )
        model.layers.register_module(layer_id, transformer_block)
    logger.info("Applied selective activation checkpointing to the model")
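To double-check which overload name actually fires on the MoE path, I've been logging dispatched ops during a forward pass. This is just a minimal sketch; `model` and `batch` are placeholders for my torchtitan MoE model and one input batch:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class GroupedMMLogger(TorchDispatchMode):
    """Print every dispatched aten overload whose name mentions 'grouped',
    so I can read off the exact entry to put in ops_to_save."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if "grouped" in str(func):
            print(func)  # expecting something like aten._grouped_mm.default
        return func(*args, **(kwargs or {}))

# model / batch are placeholders for the torchtitan MoE model and a real batch.
with GroupedMMLogger():
    model(batch)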
I am using the MoE implementation from torchtitan as a reference, and I wan