How does torch.fx work with activation checkpointing?

I am wondering how we can use torch.utils.checkpoint together with torch.fx.

Simple example is given below:

import torch
import torch.nn as nn
import torch.fx as fx
from torch.utils.checkpoint import checkpoint

class MyModule(torch.nn.Module):
    """Toy module: one linear, then the sum of two parallel linears, transposed."""

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, `self.linear1 = ...` raises
        # "AttributeError: cannot assign module before Module.__init__() call".
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 2)
        self.linear2 = torch.nn.Linear(2, 2)
        self.linear3 = torch.nn.Linear(2, 2)

    def _transform(self, x):
        """Swap the first two dimensions of *x*."""
        return x.transpose(1, 0)

    def forward(self, x):
        x = self.linear1(x)
        # Two parallel branches over the same activation, summed elementwise.
        x = self.linear2(x) + self.linear3(x)
        trans_x = self._transform(x)
        return trans_x
class NestedModule(torch.nn.Module):
    """Wraps MyModule behind torch.utils.checkpoint, then applies a final linear."""

    def __init__(self):
        # Required before registering submodules on self (nn.Module contract).
        super().__init__()
        self.my_mod = MyModule()
        self.linear4 = torch.nn.Linear(2, 2)

    def forward(self, x):
        # checkpoint() recomputes my_mod's activations during backward
        # instead of storing them, trading compute for memory.
        x = checkpoint(self.my_mod, x)
        return self.linear4(x)
# NOTE(review): symbolic_trace traces *through* the checkpoint() call and
# inlines my_mod's forward, so no checkpoint node survives in the graph.
nested_mod = NestedModule()
gm = fx.symbolic_trace(nested_mod)

The traced output silently drops the checkpoint logic: symbolic tracing traces through the `checkpoint` call and inlines `my_mod`'s operations, so the resulting graph behaves as if checkpointing were never applied.

def forward(self, x):
    # Code emitted by fx.symbolic_trace for NestedModule: my_mod's ops
    # appear inline and the checkpoint() wrapper has disappeared entirely.
    my_mod_linear1 = self.my_mod.linear1(x);  x = None
    my_mod_linear2 = self.my_mod.linear2(my_mod_linear1)
    my_mod_linear3 = self.my_mod.linear3(my_mod_linear1);  my_mod_linear1 = None
    add = my_mod_linear2 + my_mod_linear3;  my_mod_linear2 = my_mod_linear3 = None
    transpose = add.transpose(1, 0);  add = None
    linear4 = self.linear4(transpose);  transpose = None
    return linear4