torch.utils.checkpoint.checkpoint

Hi,

You need to modify your model's forward pass so that the corresponding submodule is called through checkpoint() instead of being called directly:

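from torch.utils.checkpoint import checkpoint

# Original:
out = self.my_block(inp1, inp2, inp3)

# With checkpointing (activations inside my_block are recomputed during backward instead of being stored):
out = checkpoint(self.my_block, inp1, inp2, inp3)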
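For a self-contained picture of where that call goes, here is a minimal sketch. The Block and Model classes, the dim size, and the use_checkpoint flag are hypothetical stand-ins for your own modules. It assumes a recent PyTorch (1.11 or later) where checkpoint() accepts use_reentrant; passing use_reentrant=False is the mode current PyTorch recommends, and newer versions warn if the argument is omitted. The script runs the same weights with and without checkpointing to show that the outputs and gradients match; only the memory/compute trade-off changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical three-input submodule standing in for my_block."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, a, b, c):
        return torch.relu(self.lin(a + b * c))


class Model(nn.Module):
    def __init__(self, dim: int, use_checkpoint: bool) -> None:
        super().__init__()
        self.my_block = Block(dim)
        # Hypothetical flag: a plain attribute, so it is not part of the state dict.
        self.use_checkpoint = use_checkpoint

    def forward(self, inp1, inp2, inp3):
        if self.use_checkpoint:
            # Intermediate activations of my_block are not kept; the block's
            # forward is re-run during backward to recompute them.
            return checkpoint(self.my_block, inp1, inp2, inp3, use_reentrant=False)
        # Plain call: all intermediate activations are stored for backward.
        return self.my_block(inp1, inp2, inp3)


torch.manual_seed(0)
dim = 8
plain = Model(dim, use_checkpoint=False)
ckpt = Model(dim, use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # give both models identical weights

inputs = [torch.randn(4, dim, requires_grad=True) for _ in range(3)]

out_plain = plain(*inputs)
out_plain.sum().backward()
grads_plain = [p.grad.clone() for p in plain.parameters()]

out_ckpt = ckpt(*inputs)
out_ckpt.sum().backward()
grads_ckpt = [p.grad.clone() for p in ckpt.parameters()]
```

Checkpointing saves memory at the cost of an extra forward pass through my_block during backward, so it is most useful on large blocks whose activations dominate memory use.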