# Gradient w.r.t. Inputs using BatchNorm

Hello,

I’m trying to calculate the gradient of the output of a simple neural network with respect to its inputs. The result looks fine when I don’t use a BatchNorm layer, but once I add one, the result doesn’t seem to make much sense. Below is a short example that reproduces the effect.

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
```
```python
class Net(nn.Module):
    def __init__(self, batch_norm):
        super().__init__()

        self.batch_norm = batch_norm
        self.act_fn = nn.Tanh()

        self.aff1 = nn.Linear(1, 10)
        self.aff2 = nn.Linear(10, 1)

        if batch_norm:
            self.bn = nn.BatchNorm1d(10, affine=False)  # False for simplicity

    def forward(self, x):
        x = self.aff1(x)
        x = self.act_fn(x)

        if self.batch_norm:
            x = self.bn(x)

        x = self.aff2(x)
        return x
```
```python
x_vals = torch.linspace(0, 1, 100, requires_grad=True)  # track gradients w.r.t. the inputs

fig, axs = plt.subplots(ncols=2, figsize=(16, 5))

for seed, bn, ax1 in zip([11, 7], [False, True], axs):  # different seeds for better illustration of effect
    torch.manual_seed(seed)
    net = Net(batch_norm=bn)

    net.train()
    pred = net(x_vals[:, None])
    # derivative of the predictions w.r.t. the inputs (ones as upstream gradient)
    pred_dx = torch.autograd.grad(pred, x_vals, grad_outputs=torch.ones_like(pred))[0]

    # visualization
    ax2 = ax1.twinx()

    ax1.plot(x_vals.detach(), pred.detach())
    ax2.plot(x_vals.detach(), pred_dx.detach(), linestyle='--', color='orange')

    # mark where the prediction is flattest; the derivative should cross zero there
    min_idx = torch.argmin((pred[1:] - pred[:-1]) ** 2)
    ax2.axvline(x_vals[min_idx].item(), color='gray', linestyle='dotted')
    ax2.axhline(0, color='gray', linestyle='dotted')
    ax1.set_title(('With' if bn else 'Without') + ' Batch Norm')

plt.show()
```
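One possible explanation for the BatchNorm panel, assuming `pred_dx` is obtained by backpropagating `pred.sum()` (equivalently `grad_outputs=torch.ones_like(pred)`): in training mode each of the 10 normalized features has zero mean over the batch, so after the final `nn.Linear` the sum of all predictions equals `100 * bias` regardless of the inputs. Its gradient with respect to the inputs is then zero up to rounding, so the plotted curve would be numerical noise rather than a per-sample derivative. A minimal sketch, rebuilding the same layers with `nn.Sequential` for brevity (a restructuring, not the post's `Net` class):

```python
import torch
import torch.nn as nn

torch.manual_seed(7)
# same layer stack as Net(batch_norm=True): Linear -> Tanh -> BatchNorm1d -> Linear
net = nn.Sequential(
    nn.Linear(1, 10), nn.Tanh(), nn.BatchNorm1d(10, affine=False), nn.Linear(10, 1)
)
net.train()  # BatchNorm uses the statistics of the current batch

x = torch.linspace(0, 1, 100, requires_grad=True)
pred = net(x[:, None])

# Each normalized feature sums to zero over the batch, so pred.sum() == 100 * bias
# for any x, and its gradient w.r.t. x vanishes up to floating-point error.
(pred_dx,) = torch.autograd.grad(pred.sum(), x)
print("largest |d sum(pred) / dx|:", pred_dx.abs().max().item())
print("sum of predictions:", pred.sum().item(), "100 * bias:", 100 * net[3].bias.item())
```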

The result also looks fine when I use evaluation mode (`net.eval()`). Unfortunately, I can’t just switch to `eval()` mode, because the nature of my problem (PINNs) requires calculating these gradients during training.
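Because training-mode BatchNorm normalizes with the mean and variance of the current batch, every output depends on every input in the batch, not only on its own. A quick way to check this, sketched below with a small batch of 8 points chosen for readability, is to compute the full Jacobian `jac[i, j] = d out_i / d x_j` with `torch.autograd.functional.jacobian` in both modes: in `train()` mode the off-diagonal entries are nonzero, while in `eval()` mode (running statistics) the Jacobian is diagonal, which would explain why only evaluation mode gives a sensible per-sample derivative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# same layer stack as Net(batch_norm=True), written as nn.Sequential for brevity
net = nn.Sequential(
    nn.Linear(1, 10), nn.Tanh(), nn.BatchNorm1d(10, affine=False), nn.Linear(10, 1)
)
x = torch.linspace(0, 1, 8)  # small batch so the Jacobian is easy to inspect


def f(inp):
    # map a 1-D batch of inputs to a 1-D batch of outputs
    return net(inp[:, None]).squeeze(1)


off_diag = ~torch.eye(8).bool()  # mask selecting entries with i != j

net.train()  # normalize with the current batch's mean and variance
jac_train = torch.autograd.functional.jacobian(f, x)  # jac[i, j] = d out_i / d x_j

net.eval()  # normalize with the running statistics instead
jac_eval = torch.autograd.functional.jacobian(f, x)

print("max off-diagonal, train:", jac_train[off_diag].abs().max().item())
print("max off-diagonal, eval: ", jac_eval[off_diag].abs().max().item())
```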

This question is probably related to the post “The gradients of BatchNorm layer at mode of model.train() and model.eval()”, which has no answers.

I’m using Python 3.9.5 and PyTorch 1.9.0+cu102.

I suspect that `BatchNorm1d` does not `.detach()` the batch mean and variance when normalizing. When I use the implementation below (simplified by removing the learnable weight and bias), the gradients seem to work.
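This hypothesis can be checked directly. The sketch below (random data and an arbitrary upstream gradient `g`, chosen only for illustration; `manual_bn` is a hypothetical helper) compares the input gradient of `nn.BatchNorm1d` in training mode with a manual normalization that uses the biased batch variance, once with the statistics kept in the autograd graph and once with them detached. The built-in layer is expected to agree with the non-detached version, i.e. it backpropagates through the batch mean and variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 4, requires_grad=True)  # arbitrary batch: 16 samples, 4 features
g = torch.randn(16, 4)  # arbitrary upstream gradient

bn = nn.BatchNorm1d(4, affine=False)
bn.train()
(grad_bn,) = torch.autograd.grad(bn(x), x, grad_outputs=g)


def manual_bn(inp, detach):
    # hypothetical helper: normalize with the statistics of this batch
    mean = inp.mean(dim=0)
    var = inp.var(dim=0, unbiased=False)  # training-mode BatchNorm uses the biased variance
    if detach:
        mean, var = mean.detach(), var.detach()
    return (inp - mean) / torch.sqrt(var + bn.eps)


(grad_full,) = torch.autograd.grad(manual_bn(x, detach=False), x, grad_outputs=g)
(grad_det,) = torch.autograd.grad(manual_bn(x, detach=True), x, grad_outputs=g)

matches_full = torch.allclose(grad_bn, grad_full, atol=1e-4)
matches_detached = torch.allclose(grad_bn, grad_det, atol=1e-4)
print("matches non-detached:", matches_full)
print("matches detached:    ", matches_detached)
```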

```python
class BatchNorm(nn.Module):
    def __init__(self, nFeatures, eps=1e-5, momentum=0.1):
        super().__init__()
        self.register_buffer("moving_avg", torch.zeros(nFeatures))
        self.register_buffer("moving_var", torch.ones(nFeatures))
        self.eps = eps
        self.momentum = momentum

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance for normalization, as in nn.BatchNorm1d
            # update the running estimates outside the autograd graph, using PyTorch's
            # convention new = (1 - momentum) * old + momentum * batch (unbiased variance)
            with torch.no_grad():
                n = x.shape[0]
                self.moving_avg.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.moving_var.mul_(1 - self.momentum).add_(self.momentum * var * n / (n - 1))
        else:
            mean = self.moving_avg
            var = self.moving_var

        # unlike nn.BatchNorm1d in training mode, treat the batch statistics as constants
        x_norm = (x - mean.detach()) / torch.sqrt(var.detach() + self.eps)
        return x_norm
```
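As a possible alternative to the hand-written layer (an assumption, not something tested in the post), the same "frozen batch statistics" normalization can be expressed by passing the detached batch mean and biased variance to `torch.nn.functional.batch_norm` with `training=False`, so the built-in op treats them as fixed values. The helper name `batch_norm_frozen_stats` is hypothetical, and running-statistics bookkeeping is left out of this sketch.

```python
import torch
import torch.nn.functional as F


def batch_norm_frozen_stats(x, eps=1e-5):
    # hypothetical helper: normalize with the current batch statistics,
    # but treat them as constants for autograd
    mean = x.mean(dim=0).detach()
    var = x.var(dim=0, unbiased=False).detach()  # biased variance, as used by BatchNorm
    # with training=False, F.batch_norm uses the given statistics as-is
    return F.batch_norm(x, mean, var, training=False, eps=eps)


torch.manual_seed(0)
x = torch.randn(8, 3, requires_grad=True)
y = batch_norm_frozen_stats(x)

# forward pass matches plain standardization with the batch statistics
ref = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)
same_forward = torch.allclose(y, ref, atol=1e-5)

# each output element depends only on its own input: d y_ij / d x_ij = 1 / sqrt(var_j + eps)
(grad,) = torch.autograd.grad(y.sum(), x)
expected = (1 / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)).detach().expand_as(x)
diagonal_grad = torch.allclose(grad, expected, atol=1e-5)
print(same_forward, diagonal_grad)
```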

I tried to find the exact implementation PyTorch uses, to check this. In `torch.nn.functional`, the `batch_norm` function calls `torch.batch_norm`. Could somebody point me in the direction of where that is implemented?

maybe somewhere here,