I am wondering how we can use torch.utils.checkpoint together with torch.fx.
A simple example is given below:
import torch
import torch.nn as nn
import torch.fx as fx
from torch.utils.checkpoint import checkpoint
class MyModule(torch.nn.Module):
    """Toy module: one linear layer feeding two parallel linears whose sum is transposed.

    Input is expected to be 2-D with last dimension 2 (to match the
    Linear(2, 2) layers); output has its first two dims swapped.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 2)
        self.linear2 = torch.nn.Linear(2, 2)
        self.linear3 = torch.nn.Linear(2, 2)

    def _transform(self, x):
        # Swap the first two dimensions.
        return x.transpose(1, 0)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x) + self.linear3(x)
        trans_x = self._transform(x)
        return trans_x
class NestedModule(torch.nn.Module):
    """Wraps MyModule under torch.utils.checkpoint, followed by a final linear layer."""

    def __init__(self):
        super().__init__()
        self.my_mod = MyModule()
        self.linear4 = torch.nn.Linear(2, 2)

    def forward(self, x):
        # checkpoint() discards my_mod's intermediate activations in forward
        # and recomputes them during backward to save memory.
        x = checkpoint(self.my_mod, x)
        return self.linear4(x)
# Symbolically trace the nested module. Note that fx tracing sees through
# the checkpoint() call, so the resulting graph inlines my_mod's ops.
model = NestedModule()
traced = fx.symbolic_trace(model)
traced.recompile()
print(traced)
The resulting output ignores the checkpoint logic:
def forward(self, x):
my_mod_linear1 = self.my_mod.linear1(x); x = None
my_mod_linear2 = self.my_mod.linear2(my_mod_linear1)
my_mod_linear3 = self.my_mod.linear3(my_mod_linear1); my_mod_linear1 = None
add = my_mod_linear2 + my_mod_linear3; my_mod_linear2 = my_mod_linear3 = None
transpose = add.transpose(1, 0); add = None
linear4 = self.linear4(transpose); transpose = None
return linear4