Hi, guys,
I want to use checkpointing (torch.utils.checkpoint) to reduce GPU memory usage. I know that checkpoint can be applied like this:

    import torch.utils.checkpoint as cp

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.norm2(out)
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            return out

        # Only checkpoint when gradients are needed; otherwise run normally.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out

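
A minimal, self-contained sketch of the block above that runs on CPU, for experimenting. The class name BasicBlock, the channel count, and the `use_reentrant=True` argument are assumptions: that keyword only exists in newer PyTorch versions, and True keeps the reentrant behavior that the `x.requires_grad` guard was written for. It compares input and weight gradients with and without checkpointing; they should agree, because checkpointing only recomputes the forward pass during backward instead of storing its activations.

```python
import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class BasicBlock(nn.Module):
    """ResNet-style block; `with_cp` toggles activation checkpointing."""

    def __init__(self, channels: int, with_cp: bool = False):
        super().__init__()
        self.with_cp = with_cp
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.downsample = None  # shapes already match, so no projection

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.relu(self.norm1(self.conv1(x)))
            out = self.norm2(self.conv2(out))
            if self.downsample is not None:
                identity = self.downsample(x)
            return out + identity

        if self.with_cp and x.requires_grad:
            # Activations inside _inner_forward are not kept; they are
            # recomputed from `x` when backward reaches this block.
            out = cp.checkpoint(_inner_forward, x, use_reentrant=True)
        else:
            out = _inner_forward(x)
        return self.relu(out)


torch.manual_seed(0)
plain = BasicBlock(8, with_cp=False)
ckpt = copy.deepcopy(plain)  # identical weights, checkpointing enabled
ckpt.with_cp = True

x = torch.randn(2, 8, 16, 16)
x_plain = x.clone().requires_grad_(True)
x_ckpt = x.clone().requires_grad_(True)

plain(x_plain).sum().backward()
ckpt(x_ckpt).sum().backward()

# Checkpointing trades memory for recomputation; gradients should not change.
same_input_grad = torch.allclose(x_plain.grad, x_ckpt.grad, atol=1e-5)
same_weight_grad = torch.allclose(
    plain.conv1.weight.grad, ckpt.conv1.weight.grad, atol=1e-5
)
print(same_input_grad, same_weight_grad)
```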
So, what if a bigger module contains nested custom modules, like this:

    import torch.nn as nn

    class Net(nn.Module):
        def forward(self, x):
            def _inner_forward(x):
                identity = x
                out = self.ConvBnRelu1(x)    # instance of ConvBnRelu
                out = self.ConvBnRelu2(out)  # instance of ConvBnRelu; consumes the previous output
                if self.downsample is not None:
                    identity = self.downsample(x)
                out += identity
                return out

            if self.with_cp and x.requires_grad:
                out = cp.checkpoint(_inner_forward, x)
            else:
                out = _inner_forward(x)
            out = self.relu(out)
            return out

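
A self-contained sketch of the nested case, assuming a hypothetical ConvBnRelu (conv, batch norm, ReLU) that carries its own `with_cp` flag and a hypothetical `maybe_checkpoint` helper; neither comes from a library. It only shows where an inner checkpoint would go and checks that "outer only" and "outer plus inner" checkpointing give the same input gradients as no checkpointing. It does not measure memory, so it does not by itself settle which variant is better.

```python
import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


def maybe_checkpoint(fn, x, enabled):
    """Hypothetical helper: call fn(x), checkpointed when enabled and x needs grad."""
    if enabled and x.requires_grad:
        return cp.checkpoint(fn, x, use_reentrant=True)
    return fn(x)


class ConvBnRelu(nn.Module):
    """Hypothetical child block; its own `with_cp` flag enables nested checkpointing."""

    def __init__(self, channels, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def _forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def forward(self, x):
        return maybe_checkpoint(self._forward, x, self.with_cp)


class Net(nn.Module):
    def __init__(self, channels, with_cp=False, inner_cp=False):
        super().__init__()
        self.with_cp = with_cp  # checkpoint the whole _inner_forward
        self.ConvBnRelu1 = ConvBnRelu(channels, with_cp=inner_cp)
        self.ConvBnRelu2 = ConvBnRelu(channels, with_cp=inner_cp)
        self.downsample = None  # shapes match, so no projection
        self.relu = nn.ReLU()

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.ConvBnRelu1(x)
            out = self.ConvBnRelu2(out)  # feed the first child's output
            if self.downsample is not None:
                identity = self.downsample(x)
            return out + identity

        return self.relu(maybe_checkpoint(_inner_forward, x, self.with_cp))


def input_grad(model, x):
    """Return d(sum(model(x)))/dx for a fresh copy of x."""
    x = x.detach().clone().requires_grad_(True)
    model(x).sum().backward()
    return x.grad


torch.manual_seed(0)
base = Net(8)
x = torch.randn(2, 8, 16, 16)

# Keys are (outer checkpoint on?, inner checkpoint on?).
grads = {}
for outer, inner in [(False, False), (True, False), (True, True)]:
    model = copy.deepcopy(base)  # same weights for every configuration
    model.with_cp = outer
    model.ConvBnRelu1.with_cp = inner
    model.ConvBnRelu2.with_cp = inner
    grads[(outer, inner)] = input_grad(model, x)
```

With reentrant checkpointing, the inner checkpoints only take effect while the outer block is being recomputed during backward, since the outer forward runs without building a graph.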
For the best GPU memory savings, should checkpoint also be applied inside the nested child modules (the ConvBnRelu instances)?
Any answers or guidance would be appreciated!