When I use checkpoint to wrap a module that contains BatchNorm, will it change the running statistics (running_mean / running_var) of the BN layer? I used a toy model to test the values:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class bn_linear(nn.Module):
    def __init__(self):
        super(bn_linear, self).__init__()
        self.bn = nn.BatchNorm1d(10, momentum=0.001)
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(self.bn(x))

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.s = bn_linear()

    def forward(self, inp):
        return checkpoint(self.s, inp)
        # return self.s(inp)  # swap in to compare against the non-checkpointed run

model = Model()
for i in range(1000):
    x = torch.ones(2, 10) * i
    x.requires_grad_(True)
    loss = torch.mean(model(x))
    # Watch the BN buffers (running_mean / running_var) evolve.
    print({n: p for n, p in model.named_buffers()})
But the result shows the same running statistics whether or not checkpoint is used.
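If I understand correctly, checkpoint only reruns the wrapped forward during the backward pass, and my loop never calls loss.backward(), so no recomputation happens and the BN stats are updated exactly once per iteration in both cases. Here is a minimal sketch of my own (assuming a PyTorch version that accepts the use_reentrant argument; drop it on older versions) that does call backward() and seems to show the stats being updated twice under checkpoint:

import copy
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

plain = nn.BatchNorm1d(10, momentum=0.001)   # training mode by default
wrapped = copy.deepcopy(plain)               # identical weights and buffers

x = torch.randn(4, 10, requires_grad=True)

# Plain forward/backward: BN running stats are updated once.
plain(x).mean().backward()

# Checkpointed forward/backward: the wrapped forward runs once in the
# forward pass and once more during recomputation in backward, so the
# running stats are updated twice in a single step.
checkpoint(wrapped, x, use_reentrant=True).mean().backward()

print(plain.running_mean)
print(wrapped.running_mean)  # differs: the batch mean was absorbed twice

After one step, wrapped.running_mean has had the momentum update applied twice while plain.running_mean has had it applied once, which is the difference my original loop could never surface without a backward pass.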