We are building a library where many large components can be used inside each other, e.g. a Transformer model can be used inside a UNet layer.
To save memory during training we use gradient checkpointing in both components. When we use the Transformer model inside the UNet layer, we effectively apply gradient checkpointing inside a function that is itself already gradient-checkpointed.
Is this problematic? Training seems to work just fine. Here is some pseudo code:
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class UNetModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([TransformerModel() for _ in range(10)])
    def forward(self, x):
        for layer in self.layers:  # outer checkpoint around each (already internally checkpointed) transformer
            x = checkpoint(layer, x, use_reentrant=False)
        return x

class TransformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=512, nhead=8) for _ in range(10)])  # illustrative dims
    def forward(self, x):
        for layer in self.layers:  # inner checkpoint around each individual layer
            x = checkpoint(layer, x, use_reentrant=False)
        return x
Here you can see that gradient checkpointing is used for every transformer layer, even though the whole Transformer model is already checkpointed. The reason we do this is that we want each component (UNet and Transformer) to be as memory efficient as possible out of the box, even when it is used inside other components.
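To convince ourselves that the nesting does not silently break the backward pass, we ran a quick gradient comparison along the following lines (just a sketch against the toy classes above; plain_forward is an ad-hoc helper for the comparison, not part of our library):

import copy
import torch

model_ckpt = UNetModel()
model_plain = copy.deepcopy(model_ckpt)          # identical weights
model_ckpt.eval(); model_plain.eval()            # disable dropout so both runs are comparable

x = torch.randn(16, 2, 512)                      # (seq_len, batch, d_model) matching the illustrative dims

# forward/backward through the nested-checkpoint model
model_ckpt(x).sum().backward()

def plain_forward(model, x):
    # same computation, but calling the layers directly, i.e. without any checkpointing
    for transformer in model.layers:
        for layer in transformer.layers:
            x = layer(x)
    return x

# forward/backward through the copy without checkpointing
plain_forward(model_plain, x).sum().backward()

# gradients should agree regardless of the nested checkpointing
for p_c, p_p in zip(model_ckpt.parameters(), model_plain.parameters()):
    torch.testing.assert_close(p_c.grad, p_p.grad)
print("gradients match despite nested checkpointing")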
=> Is it problematic to nest one gradient-checkpointed function inside another?