How to have batch norm not forget the batch statistics it just used?

I am in an unusual setting where I should not use running statistics (as that would be considered cheating, e.g. in meta-learning). However, I often run a forward pass on a set of points (5, in fact) and then want to evaluate on only 1 point using those previous statistics, but batch norm forgets the batch statistics it just used. I've tried to hard-code the values it should have, but I get strange errors (even when I try things taken from the PyTorch code itself, like checking the dimension size).

How do I hardcode the previous batch statistics so that batch norm works on a new single data point, and then reset them for a fresh next batch?

note: I don’t want to change the batch norm layer type.
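
For concreteness, the setup looks roughly like this (the toy model and shapes below are mine, purely for illustration):

import torch
import torch.nn as nn

# Toy stand-in for the real model (names and sizes are made up).
mdl = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
mdl.train()                    # running statistics must not be used here (that would be cheating)

support = torch.randn(5, 4)    # the 5 points whose batch statistics I want to keep
query = torch.randn(1, 4)      # the single point I want to evaluate afterwards

_ = mdl(support)               # batch norm normalizes with the statistics of these 5 points
# What I want next is mdl(query) normalized with the statistics of `support`;
# in train mode the layer instead recomputes statistics from the current batch
# (and rejects a batch of one outright), which is exactly the problem described above.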

Sample code I tried:

import torch

def set_tracking_running_stats(model):
    # Re-initialize the running statistics of every batch norm layer (found by name).
    for attr in dir(model):
        if 'bn' in attr:
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            # running_mean, running_var and num_batches_tracked are registered buffers,
            # so assign plain tensors rather than wrapping them in nn.Parameter.
            target_attr.running_mean = torch.zeros(target_attr.num_features)
            target_attr.running_var = torch.ones(target_attr.num_features)
            target_attr.num_batches_tracked = torch.tensor(0, dtype=torch.long)
            # target_attr.reset_running_stats()  # built-in equivalent of the three lines above
    return

My most common errors:

    raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)

and

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
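
For what it's worth, my own diagnosis (an assumption, not confirmed anywhere) is that both errors come from feeding a single unbatched point to BatchNorm1d, which expects input of shape (N, C) or (N, C, L): a 1D tensor trips the dimension check, and reducing over dim 1 of a 1D tensor gives exactly that IndexError. A minimal reproduction:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)       # expects input of shape (N, C) or (N, C, L)
bn.train()

batch = torch.randn(5, 8)    # the 5-point batch is fine
single = torch.randn(8)      # a lone point, shape (8,): no batch dimension

_ = bn(batch)
try:
    bn(single)               # ValueError: expected 2D or 3D input (got 1D input)
except ValueError as err:
    print(err)

# single.unsqueeze(0) has shape (1, 8) and passes the dimension check, but note that
# in train() mode a batch of one then hits a separate "Expected more than 1 value per
# channel" error, since batch statistics cannot be computed from a single point.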

related: How does pytorch’s batch norm know if the forward pass its doing is for inference or training?

related: When should one call .eval() and .train() when doing MAML with the PyTorch higher library? (Stack Overflow)

Solution: use mdl.train(); in training mode batch norm uses the current batch's statistics by itself. From the PyTorch docs:

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.

https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
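
To see what the quoted track_running_stats=False behaviour looks like in practice, here is a minimal sketch (toy layer and shapes are mine): with no running estimates kept, the layer normalizes with the current batch's statistics whether it is in train() or eval() mode.

import torch
import torch.nn as nn

# track_running_stats=False: no running estimates are kept, so the layer always
# normalizes with the statistics of the batch it is currently given.
bn = nn.BatchNorm1d(8, track_running_stats=False)

support = torch.randn(5, 8)

bn.train()
out_train = bn(support)   # normalized with the statistics of `support`

bn.eval()
out_eval = bn(support)    # still normalized with the statistics of `support`

print(torch.allclose(out_train, out_eval))  # True: batch statistics are used either way

In the meta-learning setting above, keeping the model in mdl.train() mode has the same effect: each forward pass normalizes with the statistics of the batch it is given.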