Confusion about the prepare_qat method

Here’s a recipe from the “Practical Quantization in PyTorch” post on the PyTorch blog:

# QAT follows the same steps as PTQ, with the exception of the training loop before you actually convert the model to its quantized version

import torch
from torch import nn

backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

m = nn.Sequential(
    nn.Conv2d(2, 64, 8),
    nn.ReLU(),
    nn.Conv2d(64, 128, 8),
    nn.ReLU()
)

"""Fuse"""
torch.quantization.fuse_modules(m, ['0','1'], inplace=True) # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2','3'], inplace=True) # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(), 
                  *m, 
                  torch.quantization.DeQuantStub())

"""Prepare"""
m.train()
m.qconfig = torch.quantization.get_default_qat_qconfig(backend)  # QAT needs the QAT (fake-quant) qconfig, not the PTQ get_default_qconfig
torch.quantization.prepare_qat(m, inplace=True)
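
After this, the “training loop before you actually convert” that the comment at the top alludes to would look roughly like this (my own sketch, not part of the blog recipe; the loop body is elided):

"""Train, then convert (sketch)"""
# ... run your usual training loop here; the fake-quant modules
# inserted by prepare_qat are active during these forward passes ...
m.eval()
m_int8 = torch.quantization.convert(m)  # final conversion to real quantized modules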

The prepare_qat method propagates the qconfig down to all the leaf modules, then runs convert, which swaps the modules for their QAT counterparts, and finally calls prepare, which adds observers to the modules.
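
To see this concretely, you can list the module types after prepare_qat (my own check, continuing from the recipe above):

for name, module in m.named_modules():
    print(name, type(module).__name__)

This shows the fused Conv-ReLU pairs replaced by their QAT counterparts, while the stubs still show up as plain QuantStub and DeQuantStub.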

In the code above, quant and dequant stubs are added to the model. How are they not altered by the convert call in prepare_qat?

See below for the prepare_qat implementation from quantize.py in pytorch/pytorch at commit 748d9d24940cd17938df963456c90fa1a13f3932 on GitHub:

def prepare_qat(model, mapping=None, inplace=False, allow_list=None):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to quantized version.
    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.
    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
        allow_list: a set that lists out allowable modules to be propagated with qconfig
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None, allow_list=allow_list)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
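
For reference, the default mapping that this convert call receives can be inspected directly (my own sketch; I believe get_default_qat_module_mappings lives in torch.quantization.quantization_mappings in this version):

from torch.quantization.quantization_mappings import get_default_qat_module_mappings
print(get_default_qat_module_mappings())

Comparing its keys against QuantStub/DeQuantStub should show whether the stubs are even candidates for swapping at this stage.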

Thank you!