For best utilization of GPU memory, should checkpoint also be applied to the nested child modules?

Hi, guys,
I want to use checkpointing to reduce GPU memory usage, and I know that checkpoint can be applied like this:

import torch.utils.checkpoint as cp

def forward(self, x):
    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    # Checkpoint only when gradients are needed: activations inside
    # _inner_forward are recomputed during backward instead of stored.
    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out
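A self-contained sketch of the pattern above, assuming a ResNet-style block with two 3x3 convolutions, an identity shortcut (`downsample = None`) and a PyTorch version that accepts the `use_reentrant` argument. The class name, channel count and input size are illustrative, not taken from any particular library. Checkpointing should change memory usage but not the math, so comparing gradients with and without `with_cp` is a quick sanity check.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class BasicBlock(nn.Module):
    """ResNet-style block; with_cp=True recomputes the block during backward."""

    def __init__(self, channels, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.downsample = None  # identity shortcut in this sketch

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.relu(self.norm1(self.conv1(x)))
            out = self.norm2(self.conv2(out))
            if self.downsample is not None:
                identity = self.downsample(x)
            return out + identity  # out-of-place add

        if self.with_cp and x.requires_grad:
            # Only x is kept; the intermediate activations are recomputed.
            out = cp.checkpoint(_inner_forward, x, use_reentrant=False)
        else:
            out = _inner_forward(x)
        return self.relu(out)


def grads(with_cp):
    """Return (input grad, conv1 weight grad) for a fixed random seed."""
    torch.manual_seed(0)
    block = BasicBlock(8, with_cp=with_cp)
    x = torch.randn(2, 8, 16, 16, requires_grad=True)
    block(x).sum().backward()
    return x.grad, block.conv1.weight.grad
```

Calling `grads(False)` and `grads(True)` should give matching tensors; only peak memory and compute time differ between the two.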

So, if a big module contains a nested custom module, like this:

class Net(nn.Module):
    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.ConvBnRelu1(x)    # instances of ConvBnRelu
            out = self.ConvBnRelu2(out)  # chain from ConvBnRelu1's output
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            return out

        # Checkpoint the whole block: only x is stored during forward.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out
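One way to checkpoint at both levels, sketched under assumptions: ConvBnRelu is taken to be Conv2d, BatchNorm2d and ReLU in sequence, and the `child_cp` flag, the `body` attribute and the sizes are hypothetical names introduced here. It passes `use_reentrant=False` and assumes a recent PyTorch 2.x release, since nesting non-reentrant checkpoints is only supported in newer versions. A rough reasoning, to be verified by measurement: with only the outer checkpoint, the forward pass stores just `x`, but during backward the whole `_inner_forward` is re-run with autograd recording, so the activations of both children are alive at once. Checkpointing the children as well should, in principle, cap that recompute peak at roughly one child's activations, in exchange for running each child's forward one more time.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class ConvBnRelu(nn.Module):
    """Assumed structure: Conv2d -> BatchNorm2d -> ReLU (hypothetical sketch)."""

    def __init__(self, channels, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        if self.with_cp and x.requires_grad:
            # Inner checkpoint: keep only this child's input, recompute its body.
            return cp.checkpoint(self.body, x, use_reentrant=False)
        return self.body(x)


class Net(nn.Module):
    def __init__(self, channels, with_cp=False, child_cp=False):
        super().__init__()
        self.with_cp = with_cp  # outer checkpoint around the whole block
        # child_cp: hypothetical flag enabling checkpointing inside each child
        self.ConvBnRelu1 = ConvBnRelu(channels, with_cp=child_cp)
        self.ConvBnRelu2 = ConvBnRelu(channels, with_cp=child_cp)
        self.downsample = None  # identity shortcut in this sketch
        self.relu = nn.ReLU()

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.ConvBnRelu1(x)
            out = self.ConvBnRelu2(out)
            if self.downsample is not None:
                identity = self.downsample(x)
            return out + identity

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x, use_reentrant=False)
        else:
            out = _inner_forward(x)
        return self.relu(out)


def grads(with_cp, child_cp):
    """Return (input grad, first conv weight grad) for a fixed random seed."""
    torch.manual_seed(0)
    net = Net(4, with_cp=with_cp, child_cp=child_cp)
    x = torch.randn(2, 4, 8, 8, requires_grad=True)
    net(x).sum().backward()
    return x.grad, net.ConvBnRelu1.body[0].weight.grad
```

The gradients should match for every combination of `with_cp` and `child_cp`; what differs is peak memory and compute time, which can be compared on a GPU with `torch.cuda.max_memory_allocated`. BatchNorm layers in training mode also update their running statistics on every recomputation, which is worth keeping in mind when checkpointing them.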

For best utilization of GPU memory, should checkpoint also be applied to the nested child modules of ConvBnRelu?

Any answers or guidance would be appreciated!