How do I convert a dynamic PyTorch model to ONNX?

I want to convert ExtremeNet, which is implemented as a dynamic model, to ONNX. I searched:
1. https://pytorch.org/docs/stable/onnx.html#tracing-vs-scripting
2. https://github.com/onnx/tutorials
Neither seems to explain how to convert a dynamic PyTorch model to ONNX. Is there an example of converting a dynamic model to ONNX?
Below is the core dynamic part of the code:

class exkp(nn.Module):
    """Stacked-hourglass extreme-point detector (ExtremeNet-style network).

    The image goes through ``self.pre`` and then ``nstack`` hourglass stages
    (``self.kps``), each followed by a ``cnv`` layer (``self.cnvs``).  A
    stage's ``cnv`` output feeds nine per-stage heads, always used in this
    order::

        t_heat, l_heat, b_heat, r_heat, ct_heat, t_regr, l_regr, b_regr, r_regr

    There are five ``make_heat_layer`` heads built with ``out_dim`` and four
    ``make_regr_layer`` heads built with ``2``.  Between consecutive stages,
    ``self.inters_`` / ``self.cnvs_`` / ``self.inters`` merge the stage input
    with its ``cnv`` output to form the next stage's input.

    ``forward`` dispatches on the number of positional inputs:

    * more than one -> ``_train``: the heads of every stage, with the
      regression maps gathered at the supplied index tensors;
    * one (the image) -> ``_test``: only the last stage's heads, passed to
      ``self._decode`` (``_exct_decode``), or returned as-is when called with
      ``raw_outputs=True``.

    Exporting to ONNX: all branching here (``len(xs) > 1``, the per-stage
    ``ind`` checks) is on plain Python ints.  A tracing-based export therefore
    resolves it once and records only the path taken for the example input.
    The variadic ``*xs`` / ``**kwargs`` signatures cannot be compiled by
    ``torch.jit.script``.  To keep ``self._decode`` out of the exported
    graph, export a thin wrapper module whose ``forward(image)`` returns
    ``model(image, raw_outputs=True)``, then run ``_exct_decode`` on the nine
    outputs as a post-processing step.
    """

    def __init__(
        self, n, nstack, dims, modules, out_dim, pre=None, cnv_dim=256,
        make_tl_layer=None, make_br_layer=None,
        make_cnv_layer=make_cnv_layer, make_heat_layer=make_kp_layer,
        make_tag_layer=make_kp_layer, make_regr_layer=make_kp_layer,
        make_up_layer=make_layer, make_low_layer=make_layer,
        make_hg_layer=make_layer, make_hg_layer_revr=make_layer_revr,
        make_pool_layer=make_pool_layer, make_unpool_layer=make_unpool_layer,
        make_merge_layer=make_merge_layer, make_inter_layer=make_inter_layer,
        kp_layer=residual
    ):
        """Build the stem, the hourglass stages and the per-stage heads.

        Args:
            n, dims, modules: forwarded unchanged to every ``kp_module``;
                ``dims[0]`` is also the channel count (``curr_dim``) of the
                inter-stage layers.
            nstack: number of hourglass stages.
            out_dim: last argument of every heatmap head
                (``make_heat_layer(cnv_dim, curr_dim, out_dim)``).
            pre: stem module.  When ``None``, uses
                ``convolution(7, 3, 128, stride=2)`` followed by
                ``residual(3, 128, 256, stride=2)``.
            cnv_dim: second argument of ``make_cnv_layer``.  It is also the
                first argument of every head and the input channels of the
                ``cnvs_`` 1x1 convolutions.
            make_tl_layer, make_br_layer, make_tag_layer: accepted for
                signature compatibility; not used by this constructor.
            make_*_layer, kp_layer: factories forwarded to ``kp_module`` or
                used to build the heads / inter-stage layers below.
        """
        super().__init__()
        self.nstack = nstack
        # Stored as a plain function attribute, so it is called without self.
        self._decode = _exct_decode

        curr_dim = dims[0]

        self.pre = nn.Sequential(
            convolution(7, 3, 128, stride=2),
            residual(3, 128, 256, stride=2)
        ) if pre is None else pre

        self.kps = nn.ModuleList([
            kp_module(
                n, dims, modules, layer=kp_layer,
                make_up_layer=make_up_layer,
                make_low_layer=make_low_layer,
                make_hg_layer=make_hg_layer,
                make_hg_layer_revr=make_hg_layer_revr,
                make_pool_layer=make_pool_layer,
                make_unpool_layer=make_unpool_layer,
                make_merge_layer=make_merge_layer
            ) for _ in range(nstack)
        ])
        self.cnvs = nn.ModuleList([
            make_cnv_layer(curr_dim, cnv_dim) for _ in range(nstack)
        ])

        def per_stack_heads(make_head, head_out_dim):
            # One head per stage, each make_head(cnv_dim, curr_dim, head_out_dim).
            return nn.ModuleList([
                make_head(cnv_dim, curr_dim, head_out_dim)
                for _ in range(nstack)
            ])

        # Keypoint heatmap heads.  Attribute names become state_dict keys and
        # must not change.  Construction order is also kept as before, since
        # module constructors typically draw initial weights from the global RNG.
        self.t_heats = per_stack_heads(make_heat_layer, out_dim)
        self.l_heats = per_stack_heads(make_heat_layer, out_dim)
        self.b_heats = per_stack_heads(make_heat_layer, out_dim)
        self.r_heats = per_stack_heads(make_heat_layer, out_dim)
        self.ct_heats = per_stack_heads(make_heat_layer, out_dim)

        # -2.19 ~= log(0.1 / 0.9), i.e. sigmoid(-2.19) ~= 0.1.  This is
        # presumably the focal-loss prior initialisation of the heatmap logits
        # (confirm against the loss).  nn.init.constant_ writes under no_grad,
        # matching the former ``.bias.data.fill_``.
        for heat_heads in (self.t_heats, self.l_heats, self.b_heats,
                           self.r_heats, self.ct_heats):
            for heat in heat_heads:
                nn.init.constant_(heat[-1].bias, -2.19)

        self.inters = nn.ModuleList([
            make_inter_layer(curr_dim) for _ in range(nstack - 1)
        ])

        self.inters_ = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False),
                nn.BatchNorm2d(curr_dim)
            ) for _ in range(nstack - 1)
        ])
        self.cnvs_ = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False),
                nn.BatchNorm2d(curr_dim)
            ) for _ in range(nstack - 1)
        ])

        self.t_regrs = per_stack_heads(make_regr_layer, 2)
        self.l_regrs = per_stack_heads(make_regr_layer, 2)
        self.b_regrs = per_stack_heads(make_regr_layer, 2)
        self.r_regrs = per_stack_heads(make_regr_layer, 2)

        self.relu = nn.ReLU(inplace=True)

    def _stack_heads(self, ind, cnv):
        """Run stage ``ind``'s nine heads on ``cnv``.

        Returns:
            list: ``[t_heat, l_heat, b_heat, r_heat, ct_heat,
            t_regr, l_regr, b_regr, r_regr]``, evaluated in that order.
        """
        heads = (
            self.t_heats, self.l_heats, self.b_heats, self.r_heats,
            self.ct_heats,
            self.t_regrs, self.l_regrs, self.b_regrs, self.r_regrs,
        )
        return [head[ind](cnv) for head in heads]

    def _next_inter(self, ind, inter, cnv):
        """Merge stage ``ind``'s input and ``cnv`` output into the next stage's input."""
        inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
        inter = self.relu(inter)
        return self.inters[ind](inter)

    def _train(self, *xs, **kwargs):
        """Training forward pass over every stage.

        Args:
            *xs: ``(image, t_inds, l_inds, b_inds, r_inds, ...)``; only the
                first five entries are read.  Each index tensor is passed to
                ``_tranpose_and_gather_feat`` with the matching t/l/b/r
                regression map.
            **kwargs: accepted and ignored.  ``forward`` always forwards its
                keyword arguments here; without this parameter, any keyword
                argument raised ``TypeError`` on the training path.

        Returns:
            list: ``9 * nstack`` tensors, nine per stage in the head order of
            the class docstring, with regression maps already gathered.
        """
        image = xs[0]
        # Explicit indexing keeps the IndexError on too few inputs.
        regr_inds = (xs[1], xs[2], xs[3], xs[4])

        inter = self.pre(image)
        outs = []
        for ind, (kp_, cnv_) in enumerate(zip(self.kps, self.cnvs)):
            cnv = cnv_(kp_(inter))

            heads = self._stack_heads(ind, cnv)
            heats, regrs = heads[:5], heads[5:]
            regrs = [
                _tranpose_and_gather_feat(regr, inds)
                for regr, inds in zip(regrs, regr_inds)
            ]
            outs += heats + regrs

            if ind < self.nstack - 1:
                inter = self._next_inter(ind, inter, cnv)
        return outs

    def _test(self, *xs, **kwargs):
        """Inference forward pass; only the last stage's heads are evaluated.

        Args:
            *xs: ``xs[0]`` is the image; other entries are not read.
            raw_outputs (keyword, default ``False``): if true, return the
                nine raw head outputs instead of decoding them, e.g. for ONNX
                export.  It is popped from ``kwargs``, so it is never
                forwarded to ``self._decode``.
            **kwargs: remaining keywords are forwarded to ``self._decode``.

        Returns:
            ``self._decode(...)`` applied to the last stage's nine head
            outputs.  With ``raw_outputs=True``, returns the list of those
            nine tensors in the head order of the class docstring.
        """
        raw_outputs = kwargs.pop("raw_outputs", False)
        image = xs[0]

        inter = self.pre(image)
        outs = []
        last = self.nstack - 1
        for ind, (kp_, cnv_) in enumerate(zip(self.kps, self.cnvs)):
            cnv = cnv_(kp_(inter))
            if ind == last:
                outs = self._stack_heads(ind, cnv)
            else:
                inter = self._next_inter(ind, inter, cnv)

        if raw_outputs:
            return outs
        return self._decode(*outs, **kwargs)

    def forward(self, *xs, **kwargs):
        """Dispatch to ``_train`` (more than one positional input) or ``_test``.

        ``len(xs)`` is a Python int.  Under tracing (e.g. ONNX export) the
        choice is made once, and only the selected path is recorded.
        """
        if len(xs) > 1:
            return self._train(*xs, **kwargs)
        return self._test(*xs, **kwargs)