`torch.utils.checkpoint` with BatchNorm

When I wrap a module containing BatchNorm with checkpoint, will it change the running statistics (running_mean / running_var) of the BatchNorm layer?

I used a toy model to test the values:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class bn_linear(nn.Module):
    def __init__(self):
        super(bn_linear, self).__init__()
        self.bn = nn.BatchNorm1d(10, momentum=0.001)
        self.linear = nn.Linear(10, 1)
    def forward(self, x):
        return self.linear(self.bn(x))

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.s = bn_linear()
    def forward(self, inp):
        return checkpoint(self.s, inp)
        # return self.s(inp)  # uncomment to compare without checkpoint

model = Model()

for i in range(1000):
    x = torch.ones(2, 10) * i
    x.requires_grad_(True)
    loss = torch.mean(model(x))
    # note: loss.backward() is never called in this loop

print({n:p for n, p in model.named_buffers()})

But the result shows the same running statistics whether or not checkpoint is used.
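
For reference, here is a minimal sketch of the comparison I have in mind (the helper name run_stats and the flag use_ckpt are my own, not part of any API). It adds a loss.backward() call, since my understanding is that checkpoint only re-runs the wrapped forward during the backward pass, and that is where the running statistics could get updated a second time:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def run_stats(use_ckpt, steps=100):
    # fresh BN + Linear pair for each run so the buffers start from scratch
    torch.manual_seed(0)
    bn = nn.BatchNorm1d(10, momentum=0.001)
    linear = nn.Linear(10, 1)

    def block(x):
        return linear(bn(x))

    for i in range(steps):
        x = torch.ones(2, 10) * i
        x.requires_grad_(True)
        out = checkpoint(block, x) if use_ckpt else block(x)
        # backward is what triggers checkpoint's recomputation of block(x)
        out.mean().backward()

    return bn.running_mean.clone(), bn.running_var.clone()

mean_plain, var_plain = run_stats(use_ckpt=False)
mean_ckpt, var_ckpt = run_stats(use_ckpt=True)
print(torch.allclose(mean_plain, mean_ckpt), torch.allclose(var_plain, var_ckpt))

If the printed values differ between the two runs, that would suggest the checkpointed recomputation updates the BatchNorm buffers a second time per iteration.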