import torch
import torch.nn as nn
# Define the CNN model with BatchNorm2D
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(16, momentum=1)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
self.fc = nn.Linear(16 * 480 * 848, 2)
def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.flatten(x)
x = self.fc(x)
return x
# Generate two constant random tensors
tensor1 = torch.randn(2, 3, 480, 848)
tensor2 = torch.randn(2, 3, 480, 848)
# Instantiate the model
model = SimpleCNN()
# Set up optimizer and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# L2 regularization strength
l2_lambda = 0.01
# Concatenate tensors and adjust labels
inputs = torch.cat((tensor1, tensor2), dim=0)
labels = torch.tensor([0, 1, 0, 1]) # Adjusted labels for the batch
# Set the number of epochs
num_epochs = 10
# Training loop for multiple epochs
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
# Forward pass
outputs = model(inputs)
# Compute cross-entropy loss
loss = loss_fn(outputs, labels)
# L2 regularization term
l2_reg = torch.tensor(0.)
for param in model.parameters():
l2_reg += torch.norm(param, p=2) # L2 norm
# Compute the combined loss
total_loss = loss + l2_lambda * l2_reg
# Backward pass and optimization
total_loss.backward()
optimizer.step()
# Print running_mean and running_var
print(f"Epoch {epoch+1} - Running Mean:")
print(model.bn1.running_mean)
print(f"Epoch {epoch+1} - Running Variance:")
print(model.bn1.running_var)
with torch.no_grad():
# Validation loop
model.eval()
print(f"Epoch {epoch+1} Loss {total_loss}")
For example I have generated this dummy code the outputs are:
Epoch 1 - Running Mean:
tensor([-0.0828, 0.0310, -0.1817, -0.1875, 0.0570, -0.1008, 0.1894, 0.1808,
0.1177, 0.0498, 0.1221, 0.0795, 0.1806, 0.1176, -0.0843, 0.1756])
Epoch 1 - Running Variance:
tensor([0.3489, 0.2866, 0.3474, 0.3802, 0.2384, 0.3030, 0.2698, 0.4196, 0.3284,
0.3617, 0.4089, 0.3673, 0.4466, 0.3379, 0.3950, 0.3943])
Epoch 1 Loss 0.6932005882263184
Epoch 2 - Running Mean:
tensor([-0.0828, 0.0310, -0.1817, -0.1875, 0.0569, -0.1008, 0.1894, 0.1807,
0.1177, 0.0498, 0.1221, 0.0795, 0.1806, 0.1175, -0.0843, 0.1755])
Epoch 2 - Running Variance:
tensor([0.3489, 0.2866, 0.3474, 0.3802, 0.2384, 0.3030, 0.2698, 0.4196, 0.3284,
0.3616, 0.4089, 0.3673, 0.4465, 0.3379, 0.3950, 0.3942])
Epoch 2 Loss 0.11885640770196915
Epoch 3 - Running Mean:
tensor([-0.0827, 0.0310, -0.1817, -0.1875, 0.0569, -0.1008, 0.1894, 0.1807,
0.1177, 0.0497, 0.1220, 0.0795, 0.1806, 0.1175, -0.0843, 0.1755])
Epoch 3 - Running Variance:
tensor([0.3488, 0.2866, 0.3474, 0.3802, 0.2383, 0.3030, 0.2698, 0.4196, 0.3283,
0.3616, 0.4089, 0.3673, 0.4465, 0.3379, 0.3950, 0.3942])
Epoch 3 Loss 0.11885036528110504
Epoch 4 - Running Mean:
tensor([-0.0827, 0.0310, -0.1816, -0.1874, 0.0569, -0.1007, 0.1893, 0.1806,
0.1177, 0.0497, 0.1220, 0.0795, 0.1805, 0.1175, -0.0842, 0.1755])
Epoch 4 - Running Variance:
tensor([0.3488, 0.2866, 0.3473, 0.3801, 0.2383, 0.3030, 0.2698, 0.4195, 0.3283,
0.3616, 0.4088, 0.3673, 0.4465, 0.3378, 0.3949, 0.3942])
Epoch 4 Loss 0.11884436011314392
Epoch 5 - Running Mean:
tensor([-0.0827, 0.0310, -0.1816, -0.1874, 0.0569, -0.1007, 0.1893, 0.1806,
0.1176, 0.0497, 0.1220, 0.0795, 0.1805, 0.1175, -0.0842, 0.1754])
Epoch 5 - Running Variance:
tensor([0.3488, 0.2865, 0.3473, 0.3801, 0.2383, 0.3029, 0.2698, 0.4195, 0.3283,
0.3615, 0.4088, 0.3672, 0.4464, 0.3378, 0.3949, 0.3941])
Epoch 5 Loss 0.11883837729692459
Epoch 6 - Running Mean:
tensor([-0.0827, 0.0310, -0.1816, -0.1873, 0.0569, -0.1007, 0.1893, 0.1806,
0.1176, 0.0497, 0.1220, 0.0795, 0.1805, 0.1174, -0.0842, 0.1754])
Epoch 6 - Running Variance:
tensor([0.3487, 0.2865, 0.3473, 0.3801, 0.2383, 0.3029, 0.2697, 0.4195, 0.3283,
0.3615, 0.4088, 0.3672, 0.4464, 0.3378, 0.3949, 0.3941])
Epoch 6 Loss 0.11883237212896347
Epoch 7 - Running Mean:
tensor([-0.0827, 0.0310, -0.1815, -0.1873, 0.0569, -0.1007, 0.1892, 0.1805,
0.1176, 0.0497, 0.1220, 0.0794, 0.1804, 0.1174, -0.0842, 0.1754])
Epoch 7 - Running Variance:
tensor([0.3487, 0.2865, 0.3473, 0.3800, 0.2383, 0.3029, 0.2697, 0.4194, 0.3282,
0.3615, 0.4087, 0.3672, 0.4464, 0.3378, 0.3948, 0.3941])
Epoch 7 Loss 0.11882635205984116
Epoch 8 - Running Mean:
tensor([-0.0827, 0.0310, -0.1815, -0.1873, 0.0569, -0.1007, 0.1892, 0.1805,
0.1176, 0.0497, 0.1219, 0.0794, 0.1804, 0.1174, -0.0842, 0.1753])
Epoch 8 - Running Variance:
tensor([0.3487, 0.2865, 0.3472, 0.3800, 0.2382, 0.3029, 0.2697, 0.4194, 0.3282,
0.3615, 0.4087, 0.3671, 0.4463, 0.3377, 0.3948, 0.3940])
Epoch 8 Loss 0.11882033944129944
Epoch 9 - Running Mean:
tensor([-0.0826, 0.0310, -0.1815, -0.1872, 0.0569, -0.1006, 0.1892, 0.1805,
0.1176, 0.0497, 0.1219, 0.0794, 0.1804, 0.1174, -0.0842, 0.1753])
Epoch 9 - Running Variance:
tensor([0.3487, 0.2864, 0.3472, 0.3800, 0.2382, 0.3028, 0.2697, 0.4193, 0.3282,
0.3614, 0.4087, 0.3671, 0.4463, 0.3377, 0.3948, 0.3940])
Epoch 9 Loss 0.11881443113088608
Epoch 10 - Running Mean:
tensor([-0.0826, 0.0310, -0.1814, -0.1872, 0.0569, -0.1006, 0.1891, 0.1804,
0.1175, 0.0497, 0.1219, 0.0794, 0.1803, 0.1174, -0.0841, 0.1753])
Epoch 10 - Running Variance:
tensor([0.3486, 0.2864, 0.3472, 0.3799, 0.2382, 0.3028, 0.2696, 0.4193, 0.3282,
0.3614, 0.4086, 0.3671, 0.4462, 0.3377, 0.3947, 0.3940])
Epoch 10 Loss 0.11881034821271896```
You can see that the values slightly change. This being a small network it might not influence it as much but when talking about a huge one for example my HRNet I think that those small changes can accumulate and make the loss explode.