BatchNorm degrades performance with different batch sizes

Hi,
I have a problem with BatchNorm1d. There are some existing discussions on this, but they didn't help.
I train a network (with BatchNorm1d) with batch_size 512. Training and evaluation work fine as long as the batch size stays the same, but if I change the batch size from 512 to another value, the accuracy drops.
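
To make the effect concrete, here is a minimal standalone sketch (not my actual code) showing that a bare BatchNorm1d with track_running_stats=False normalizes the same sample differently depending on which batch it sits in:

import torch as T
import torch.nn as nn

T.manual_seed(0)
bn = nn.BatchNorm1d(4, track_running_stats=False)
bn.eval()  # eval() makes no difference here: batch statistics are still used

x = T.randn(512, 4)
out_full = bn(x)        # first 32 rows normalized with the stats of all 512
out_small = bn(x[:32])  # same 32 rows normalized with the stats of only 32
print(T.allclose(out_full[:32], out_small))  # False for random data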

network:

import torch as T
import torch.nn as nn
from torch.optim import AdamW

device = T.device('cuda' if T.cuda.is_available() else 'cpu')


class Model(nn.Module):
    def __init__(self, outdim, momentum, track=False):
        super().__init__()
        self.inp = nn.Sequential(
            nn.BatchNorm1d(256, track_running_stats=track, momentum=momentum),
        )
        self.rnn1 = nn.GRU(7, 32, batch_first=True, bidirectional=True)
        self.attn1 = nn.MultiheadAttention(32 * 2, 8, dropout=.2)
        self.model = nn.Sequential(
            nn.BatchNorm1d(256 * 32 * 2, track_running_stats=track, momentum=momentum),
            nn.Linear(256 * 32 * 2, 64), nn.ReLU(),
            nn.BatchNorm1d(64, track_running_stats=track, momentum=momentum),
            nn.Linear(64, 128), nn.ReLU(),
            nn.BatchNorm1d(128, track_running_stats=track, momentum=momentum),
            nn.Linear(128, 128), nn.ReLU(),
        )
        self.out = nn.Sequential(
            nn.BatchNorm1d(128, track_running_stats=track, momentum=momentum),
            nn.Linear(128, outdim),
        )
        self.opt = AdamW(self.parameters(), .0001)
        self.loss = nn.CrossEntropyLoss()
        self.to(device)

    def forward(self, x):
        x = T.as_tensor(x).to(device).float()  # x: (batch, 256, 7)
        x = self.inp(x)
        rnn1 = self.rnn1(x)[0]                 # (batch, 256, 64)
        # self-attention with query = key = value = rnn1 (sequence-first)
        attn1 = self.attn1(*[rnn1.transpose(0, 1)] * 3)[0].transpose(0, 1)
        model = self.model(attn1.flatten(1))   # flatten to (batch, 256*64)
        return self.out(model)

results:
After running the following:

import numpy as np
from IPython import display

model2 = Model(2, 0)  # outdim=2, momentum=0, track defaults to False
model2.load_state_dict(model.state_dict())
# warm-up forward passes (with track_running_stats=False these do not
# update any running statistics)
with T.no_grad():
    model2(x_train[:32])
    model2(x_train[32:32 * 2])
    model2(x_train[32 * 2:32 * 3])
    model2(x_train[32 * 3:32 * 4])
model2.eval()
with T.no_grad():
    accs = []
    # data.batch is my own batching helper
    for j, (x, y) in enumerate(data.batch(x=x_train, y=y_train, batch_size=32)):
        output = model2(x).cpu().numpy().argmax(-1).reshape(-1)
        accs.append((output == y.astype(int)).sum() / len(y))
        print(str(j / (len(y_train) / x.shape[0]))[:7])  # progress fraction
        display.clear_output(1)
print(str(np.mean(accs))[:7])

With batch_size=512 in the evaluation loop, the result is:
Accuracy: 85.5%
With batch_size=32 (as pasted above), it drops to:
Accuracy: 47.5%
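
To confirm it really is the predictions that change (and not a bug in my accuracy loop), here is a sketch of the check I would run, assuming x_train as above: the same 512 samples evaluated as one batch vs. sixteen batches of 32.

model2.eval()
with T.no_grad():
    pred_512 = model2(x_train[:512]).argmax(-1)
    pred_32 = T.cat([model2(x_train[i:i + 32]).argmax(-1)
                     for i in range(0, 512, 32)])
# fraction of samples where the two batch sizes agree; with
# track_running_stats=False each batch is normalized with its own stats
print((pred_512 == pred_32).float().mean().item())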

By the way, I also tried changing the momentum in model2, but it had no effect.
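
I suspect the momentum is simply ignored here: with track_running_stats=False the layer keeps no running estimates at all, so there is nothing for momentum to update. A quick check:

# with track_running_stats=False the running buffers are None,
# so momentum is never applied
bn = nn.BatchNorm1d(64, track_running_stats=False, momentum=.5)
print(bn.running_mean, bn.running_var)  # None None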