This code snippet reproduces the issue for me on PyTorch 1.1.0:
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleNet(nn.Module):
    def __init__(self, image_size_total):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(64)  # removing this layer removes the discrepancy
        self.max_pool1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear((image_size_total // 4) * 64, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.max_pool1(x)
        x = x.view(-1, self.num_flat_features(x))
        x = self.fc1(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
width = 64
height = 64
network = SimpleNet(width * height)
batchone = torch.ones([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputone = torch.tensor([.5, .5]).to(torch.device("cuda:1"))
batchtwo = torch.randn([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputtwo = torch.tensor([.01, 1.0]).to(torch.device("cuda:1"))
def train_net(net, batch, output):
    net.train()
    optimizer = optim.SGD(net.parameters(), 0.0001)
    criterion = nn.MSELoss()
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward + backward + optimize
    netoutput = net(batch)
    loss = criterion(netoutput, output)
    loss.backward()
    optimizer.step()
    return float(loss)
def evaluate_batch(net, batch, output, shouldeval):
    if shouldeval:
        net.eval()
    else:
        net.train()
    criterion = nn.MSELoss()
    # forward only; no backward pass or optimizer step
    netoutput = net(batch)
    loss = criterion(netoutput, output)
    return float(loss)
network.to(torch.device("cuda:1"))
for i in range(100):
    print("t loss1:", train_net(network, batchone, outputone))
    print("t loss2:", train_net(network, batchtwo, outputtwo))
    print("v loss1:", evaluate_batch(network, batchone, outputone, True))
    print("v loss2:", evaluate_batch(network, batchtwo, outputtwo, True))
    print("train v loss1:", evaluate_batch(network, batchone, outputone, False))
    print("train v loss2:", evaluate_batch(network, batchtwo, outputtwo, False))
If I remove the BatchNorm line (self.bn1), the discrepancy goes away.
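This is consistent with how BatchNorm is documented to behave: in train() mode it normalizes each batch with that batch's own statistics (while updating running estimates, with momentum 0.1 by default), whereas in eval() mode it normalizes with the accumulated running_mean/running_var. A minimal sketch of that difference, separate from the network above (the tensor shapes and the scale/shift of the input are arbitrary choices for illustration):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
# input whose statistics are far from the initial running estimates
# (running_mean=0, running_var=1)
x = torch.randn(4, 3, 8, 8) * 5 + 10

bn.train()
out_train = bn(x)  # normalized with this batch's mean/var; running stats get updated

bn.eval()
out_eval = bn(x)   # normalized with bn.running_mean / bn.running_var instead

# large until the running estimates converge to the data statistics
print((out_train - out_eval).abs().max())

Until the running estimates have seen enough batches to converge, eval() output (and hence the loss) can differ noticeably from train() output on the same batch, which would contribute to the gap between the "v" and "train v" numbers above.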