You can easily replicate the behavior of sklearn's `StandardScaler` in PyTorch with this small script:

# Assumes `import torch` and `from sklearn.preprocessing import StandardScaler`.
x = torch.randn(10, 5) * 10
scaler = StandardScaler()
arr_norm = scaler.fit_transform(x.numpy())

# PyTorch impl
m = x.mean(0, keepdim=True)
# unbiased=False -> population std (divide by N), which is what sklearn uses.
s = x.std(0, unbiased=False, keepdim=True)
# sklearn leaves constant features unscaled (scale of 1) instead of dividing
# by zero; mirror that so the two results agree on such columns too.
s = torch.where(s == 0, torch.ones_like(s), s)
# Out-of-place ops keep `x` intact and autograd-friendly: an in-place
# `x -= m` raises on a leaf tensor that has requires_grad=True.
x_norm = (x - m) / s
print(torch.allclose(x_norm, torch.from_numpy(arr_norm)))

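Alternatively, you could of course just use the sklearn scaler directly. `tensor.numpy()` and `torch.from_numpy()` share the underlying memory instead of copying it, so the conversion is very cheap. Note, however, that this round trip only works for CPU tensors and leaves the autograd graph. `.numpy()` refuses a tensor that requires grad, so you must call `.detach()` first, and gradients will not flow through the sklearn scaler.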

I have an input with `requires_grad=True`. I need to scale it, and I'm wondering whether the solution in this post would break the graph, so that the gradient could no longer be computed later.

FWIW, I had implemented something similar before stumbling upon StandardScaler, though I gave it a slightly worse name:

class DatasetNorm1d(nn.Module):
    """Standardize features using statistics recorded once from a dataset.

    Similar to sklearn's StandardScaler:
    https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html

    The per-feature mean and variance are registered as buffers, not
    parameters. They are part of ``state_dict`` but are not returned by
    ``parameters()``. Both start out as NaN, which marks the module as
    uninitialized until ``initialize`` is called.

    ``forward`` uses only out-of-place ops, so gradients flow back to an
    input that has ``requires_grad=True``.

    Args:
        num_features: Number of features, i.e. the size of the dimension
            that ``mean``/``var`` broadcast against (the last one).
    """

    def __init__(self, num_features):
        super().__init__()
        self._num_features = num_features
        # NaN sentinel means "no statistics recorded yet".
        self.register_buffer("mean", torch.full((num_features,), float("nan")))
        self.register_buffer("var", torch.full((num_features,), float("nan")))

    def _is_initialized(self):
        """Return True once neither statistics buffer contains NaN."""
        return not (torch.isnan(self.mean).any() or torch.isnan(self.var).any())

    @torch.no_grad()
    def initialize(self, input_batch, *, unbiased=True):
        """Record the per-feature mean and variance of a whole dataset.

        Args:
            input_batch: Tensor of shape ``(N, num_features)`` that should
                represent *all* inputs for a given dataset.
            unbiased: If True (default, the original behavior), use the
                Bessel-corrected sample variance (divide by ``N - 1``).
                Pass False for the population variance (divide by ``N``),
                which is what sklearn's StandardScaler uses.

        Raises:
            RuntimeError: If statistics were already recorded.
            ValueError: If ``input_batch`` is not ``(N, num_features)``, has
                too few rows to estimate a variance, or yields NaN statistics.

        Example:
            norm.initialize(torch.cat([x for (x, _) in train_loader]))
        """
        # TODO(eric.cousineau): Use an accurate running computation?
        # See: https://github.com/pytorch/pytorch/blob/480851ad/aten/src/ATen/native/Normalization.cpp#L215-L269
        # Explicit raises (rather than asserts) so validation survives `python -O`.
        if self._is_initialized():
            raise RuntimeError("DatasetNorm1d is already initialized.")
        if input_batch.dim() != 2 or input_batch.shape[1] != self._num_features:
            raise ValueError(
                f"Expected input_batch of shape (N, {self._num_features}), "
                f"got {tuple(input_batch.shape)}."
            )
        num_rows = input_batch.shape[0]
        # The unbiased estimator divides by N - 1, so it needs two rows.
        min_rows = 2 if unbiased else 1
        if num_rows < min_rows:
            raise ValueError(
                f"Need at least {min_rows} rows to estimate the variance, "
                f"got {num_rows}."
            )
        var, mean = torch.var_mean(input_batch, dim=0, unbiased=unbiased)
        # NaN statistics would be indistinguishable from the "uninitialized"
        # sentinel and surface later as a misleading error in forward().
        if torch.isnan(var).any() or torch.isnan(mean).any():
            raise ValueError("input_batch produced NaN statistics; does it contain NaN?")
        # copy_ writes into the registered buffers in place (casting to their
        # dtype/device), so the buffer objects stay the same.
        self.var.copy_(var)
        self.mean.copy_(mean)

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var)`` using the recorded statistics.

        Features whose recorded variance is exactly zero are only centered
        (their scale is treated as 1), as sklearn does, instead of producing
        inf/NaN.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if not self._is_initialized():
            raise RuntimeError("This must be initialized on the dataset!")
        std = torch.sqrt(self.var)
        std = torch.where(std == 0, torch.ones_like(std), std)
        return (x - self.mean) / std