How to apply selective activation checkpointing to `_grouped_mm`

From the selective activation checkpointing blog post, I understood that we can save the outputs of compute-intensive operations while recomputing the memory-intensive ones. How do I add `_grouped_mm` to `ops_to_save` below?

Does adding `torch.ops.aten.grouped_mm.default` directly work, or does adding `torch._grouped_mm` do the trick?

```python
import logging
from functools import partial

import torch

logger = logging.getLogger(__name__)

ops_to_save = [
    torch.ops.aten.mm.default,
    torch.ops.aten.bmm.default,
    torch.ops.aten.addmm.default,
    torch.ops.aten.grouped_mm.default,
    # Save collectives that produce dynamic-shaped outputs
    torch.ops._c10d_functional.all_to_all_single.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
]


def _apply_ac(model: torch.nn.Module):
    """Apply selective activation checkpointing to the model."""
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper, CheckpointImpl

    # Create a policy function that returns a CheckpointPolicy
    def policy_fn(ctx, op, *args, **kwargs):
        if op in ops_to_save:
            return CheckpointPolicy.MUST_SAVE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    # Wrap each transformer block and swap it back into model.layers
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = checkpoint_wrapper(
            transformer_block,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
        )
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Applied selective activation checkpointing to the model")
```

I am using the MoE implementation from torchtitan for reference. I wan