Hi,
You need to modify the forward pass so that the corresponding submodule is called through checkpoint instead of directly:
# Original:
out = self.my_block(inp1, inp2, inp3)
# With checkpointing:
from torch.utils.checkpoint import checkpoint
out = checkpoint(self.my_block, inp1, inp2, inp3)
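Here is a minimal self-contained sketch of the change in context. It assumes PyTorch 1.11 or later, which accepts the use_reentrant argument; recent versions warn if you leave it out. MyBlock, MyModel and the layer sizes are hypothetical stand-ins for your own submodule and model. The checkpointed call should give the same gradients as the direct call, but the activations inside my_block are recomputed during backward instead of being kept in memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MyBlock(nn.Module):
    # Hypothetical submodule standing in for `self.my_block`.
    def __init__(self, dim: int):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)

    def forward(self, inp1, inp2, inp3):
        return self.lin2(torch.relu(self.lin1(inp1 + inp2))) * inp3


class MyModel(nn.Module):
    # Hypothetical model that owns the block and chooses how to call it.
    def __init__(self, dim: int, use_checkpoint: bool = True):
        super().__init__()
        self.my_block = MyBlock(dim)
        self.use_checkpoint = use_checkpoint

    def forward(self, inp1, inp2, inp3):
        if self.use_checkpoint and self.training:
            # Intermediate activations of my_block are not stored;
            # the block is run again during the backward pass.
            return checkpoint(self.my_block, inp1, inp2, inp3, use_reentrant=False)
        # Plain call: all intermediate activations are kept for backward.
        return self.my_block(inp1, inp2, inp3)


def grads(model, a, b, c):
    # Run forward + backward and return copies of the parameter gradients.
    model.zero_grad()
    model(a, b, c).sum().backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
ref = MyModel(8, use_checkpoint=False)
ckpt = MyModel(8, use_checkpoint=True)
ckpt.load_state_dict(ref.state_dict())  # same weights in both models

a, b, c = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
ref_grads = grads(ref, a, b, c)
ckpt_grads = grads(ckpt, a, b, c)
```

Checkpointing trades extra compute for lower memory, so it is usually applied only to the largest blocks and only while training. That is why the sketch falls back to the direct call in eval mode.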