Pathological loss values when model reloaded

Hi!
I have trained a Resnet20 model and recorded the train and test losses epoch by epoch. Here is the plot:
[plot: train and test losses per epoch]
The erratic behaviour of the test loss is another subject (see “Resnet: problem with test loss” for details).

Schematically, during the training phase of the model, after all the data loading and the optimizer/scheduler initialization, for each epoch:

for epoch in range(start_epoch, args.epochs + 1):
    train_loss = train(args, model, device, train_loader, train_transforms, optimizer, epoch)
    test_loss = test(args, model, device, test_loader, test_transforms)
    # save model ...
    state = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(state, "model.pth")

In the train() function:

def train(args, model, device, train_loader, transforms, optimizer, epoch, attack=None, **attack_args):

    # switch network layers to Training mode
    model.train()

    train_loss = 0 # to get the mean loss over the dataset (JEC 15/11/19)
    # scans the batches
    for i_batch, sample_batched in enumerate(train_loader):
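        img_batch = sample_batched['image']   #input image
        ebv_batch = sample_batched['ebv']     #input reddening
        z_batch   = sample_batched['z']       #target

        batch_size  = len(img_batch)   # actual size of this batch
        transf_size = len(transforms)  # number of transforms to apply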

        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)

        #for CrossEntropyLoss no hot-vector
        new_z_batch = torch.zeros(batch_size,dtype=torch.long)

        for i in range(batch_size):
            # transform the images
            img = img_batch[i].numpy()
            for it in range(transf_size):
                img = transforms[it](img)
            new_img_batch[i] = img

            # transform the redshift in bin number
            z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
            z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
            new_z_batch[i] = z

        # send the inputs and target to the device
        new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                            ebv_batch.to(device), \
                                            new_z_batch.to(device)

        # reset the gradients
        optimizer.zero_grad()
        # Feedforward
        output = model(new_img_batch, ebv_batch)
        # the loss
        loss = F.cross_entropy(output,new_z_batch)
        train_loss += loss.item() * batch_size 
        # backprop to compute the gradients
        loss.backward()
        # perform an optimizer step to modify the weights
        optimizer.step()

    # return some stat
    return train_loss/len(train_loader.dataset)
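
The redshift-to-bin conversion used inside the loop can be checked in isolation. This is only a minimal sketch: `redshift_to_bin` is a hypothetical helper name, and the default values for `z_min`, `z_max` and `n_bins` are assumptions taken from the debugging script and the `resnet20()` definition shown later in the thread.

```python
def redshift_to_bin(z, z_min=0.0, z_max=1.0, n_bins=180):
    """Map a real-valued redshift to an integer class index in [0, n_bins - 1]."""
    frac = (z - z_min) / (z_max - z_min)   # rescale to [0, 1]
    idx = int(frac * n_bins)               # truncate toward zero
    return max(0, min(n_bins - 1, idx))    # clamp so that z == z_max stays in range


# Values at the edges and outside of [z_min, z_max] are clamped to valid classes
bins = [redshift_to_bin(z) for z in (0.0, 0.5, 1.0, 1.3, -0.2)]
print(bins)
```

The clamp matters for F.cross_entropy, which expects targets strictly below the number of classes.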

For the test() function it is essentially the same code but with

  1. model.eval() switch and
  2. the use of with torch.no_grad(): which is above the loop on the batches

Now, I have set up another program for debugging. The idea is the following: once the last “model.pth” checkpoint is loaded,

  1. use the same training and testing samples as in the training job described above, with the same random seed initialization and the same data augmentation scheme
  2. in place of the train/test functions, use a single process function which sets model.eval() and with torch.no_grad(): and then loops over the batches to compute the mean losses

So, I would have expected the model parameters (notably the BatchNorm statistics) to be frozen, such that I would recover the train and test loss values. But this is not the case:

Train mean loss over  781  samples =  4.95804432992288
Test mean loss over  781  samples =  4.958497584095075

Do you have any idea why, for instance, the loss computed on the same training set used during the training job is around 5, while it was around 2.5 during training?

Has the model, saved at each epoch after a train phase followed by a test phase, lost the BatchNorm parameters, so that after reloading the model the two sets are seen as fresh sets?
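
One way to check whether the BatchNorm statistics are part of what gets saved is to compare named_parameters() with state_dict(): the running mean and variance are buffers, so they appear in the state dict but not among the learnable parameters. The sketch below uses a tiny stand-in network (an assumption for illustration, not the ResNet20 of this thread) and an in-memory buffer instead of a file.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in network (assumption), with one BatchNorm layer
net = nn.Sequential(nn.Conv2d(5, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4))

# A forward pass in training mode updates the BatchNorm running statistics
net.train()
net(torch.randn(8, 5, 16, 16))

param_names = [name for name, _ in net.named_parameters()]
state_names = list(net.state_dict().keys())
print("parameters :", param_names)
print("state_dict :", state_names)

# Save/reload round trip through an in-memory buffer
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

reloaded = nn.Sequential(nn.Conv2d(5, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4))
reloaded.load_state_dict(torch.load(buffer))

same_mean = torch.equal(net[1].running_mean, reloaded[1].running_mean)
same_var = torch.equal(net[1].running_var, reloaded[1].running_var)
print("running stats preserved:", same_mean and same_var)
```

A loop over named_parameters() alone would therefore not show the running statistics; iterating over state_dict().items() or named_buffers() does.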

Result from an experiment: I have trained the model without performing any test phase, to avoid calling model.eval() during the training and before saving the model. After 30 epochs I got a train loss of <~ 2.5.

But the result is the same: both the train loss and the test loss are close to 6!

Could you try to create a (small) reproducible code snippet by removing unnecessary utility functions, so that we could have a look at it?

Well, let’s try. This is the debugging script, which seems pathological, as one can see in the plots of 22nd Dec. 2019 of BatchNorm parameters after model reload (the title is in fact inappropriate, as BN seems not to be responsible).

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
import types
import os

# the Network Model Classes
from pz_network import *
# the Utility Classes to manipulate the images
from pz_image_manip_utils import *


############## Process BATCH ################
def process(model,sample_batched,transforms,train=None,debug=False):
    #We load the model checkpoint so no more training
    model.eval()

    n_bins = model.n_bins # the last number of neurons

    z_min = 0.0
    z_max = 1.0

    # turn off the computation of the gradient for all tensors
    with torch.no_grad():

        #get next data
        img_batch = sample_batched['image']   #input image
        ebv_batch = sample_batched['ebv']     #input reddening
        z_batch   = sample_batched['z']       #target

        batch_size  = len(img_batch)
        transf_size = len(transforms)

        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)

        #for CrossEntropyLoss no hot-vector
        new_z_batch = torch.zeros(batch_size,dtype=torch.long)

        for i in range(batch_size):
            # transform the images
            img = img_batch[i].numpy()

            for it in range(transf_size):
                img = transforms[it](img)

            new_img_batch[i] = img

            # transform the redshift in bin number
            z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
            z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
            new_z_batch[i] = z

        # send the inputs and target to the device
        new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                                ebv_batch.to(device), \
                                                new_z_batch.to(device)


        # Feedforward
        output = model(new_img_batch, ebv_batch)
        print("ebv_batch: ",ebv_batch[:5])
        print("new_img_batch:\n",new_img_batch[:5])
        print("output:\n",output[:5])

        # the loss
        loss = F.cross_entropy(output,new_z_batch)

    return loss.item()


############## MAIN ################

use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
print("Use device....: ",device)

# pin_memory speeds up the CPU->GPU transfer on NVIDIA GPUs
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

#model
model = resnet20()
# in-place method to transfer to GPU
# according to https://pytorch.org/docs/stable/nn.html#torch.nn.Module.to
model.to(device)

#load model parameters
root_file = "/sps/lsst/users/campagne/torchphotoz/"
checkpoint_file ="robust-state-50_0_resnet20.pth.sav"

checkpoint_file = root_file + checkpoint_file
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['model_state_dict'])

print("Just after model.load_state_dict:")
for name, param in model.named_parameters():
    print('name: ', name)
    print('param.shape: ', param.shape)
    print(f"min/max/mean/std: {param.min().item():0.2e},{param.max().item():0.2e},{param.mean().item():0.2e},{param.std().item():0.2e}")
    print('=====')


#train batch
train_file = '/sps/lsst/data/campagne/sdss/train100k.npz'

#test batch
test_file  = '/sps/lsst/data/campagne/sdss/test100k.npz'

train_batch_size = 128
test_batch_size  = 128

# Train data
train_dataset = DatasetPz(train_file, transform=None)
        
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,drop_last=True,
    shuffle=False, **kwargs)

trainloader_iterator = iter(train_loader)
print("train_loader length=",len(train_loader), " WARNING: NO SHUFFLE")

# Test data
test_dataset = DatasetPz(test_file, transform=None)
test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=test_batch_size, drop_last=True,
        shuffle=False, **kwargs)

testloader_iterator = iter(test_loader)
print("test_loader length=",len(test_loader))

# set manual seeds per epoch JEC 2/11/19 fix seed 0 once for all
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)


#
#train_transforms = [RandomApplyPz([[flipH,flipV,identity],
#[rot90,rot180,rot270,identity]]),ToTensorPz()]

no_train_transforms = [ToTensorPz()]

test_transforms = [ToTensorPz()]

train_loss=0
train_num_batch = 10 if len(train_loader)>10 else len(train_loader)
for i in range(train_num_batch):
    print("train[",str(i),"]: no data augm")
    sample_batched = next(trainloader_iterator)
    train_loss += process(model,sample_batched,no_train_transforms,train=True)
print('Train mean loss over ',train_num_batch,' batches = ',train_loss/train_num_batch)

test_loss = 0
test_num_batch = 10 if len(test_loader)>10 else len(test_loader)
for i in range(test_num_batch):    
    print("test [",str(i),"]")
    sample_batched = next(testloader_iterator)
    test_loss += process(model,sample_batched,test_transforms,train=False)
print('Test mean loss over ',test_num_batch,' batches = ',test_loss/test_num_batch)

The dataset class used to load the data, and the tensor transform:

class DatasetPz(Dataset):
    """ Load the data set which is supposed to be a Numpy structured array
        'z' : the true redshift array
        'ebv': reddening array
        'data': the images tensors  N H W C with C: nber of channels (ex. 5 filters)
    """

    def __init__(self, file_path, transform=None):
        self.data = np.load(file_path)  # load dataset into numpy array  N H W C
        #print('type de dataset data = ',self.data['data'].dtype)
        #transform into Float32
        self.z = self.data['z'].astype("float32")
        self.ebv = self.data['ebv'].astype("float32")
        self.imgs = self.data['data'].astype("float32")

        self.z_min = np.min(self.z)
        self.z_max = np.max(self.z)

        print("DatasetPz zmin: ",self.z_min," zmax: ",self.z_max)

        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index:  index of the image, redshift, reddening

        Returns
        -------
        a dictionary with the image, redshift, reddening

        Apply a transform to the image if needed
        """
        image = self.imgs[index]
        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'z': self.z[index], 'ebv': self.ebv[index]}

class ToTensorPz(object):
    """ Transform the tensor from HWC(TensorFlow default) to CHW (Torch)
    """
    def __call__(self, pic):
        # use copy here to avoid crash
        img = torch.from_numpy(pic.transpose((2, 0, 1)).copy())
        return img
    def __repr__(self):
        return self.__class__.__name__ + '()'

and the network, which I have adapted for the number of input channels of the 1st conv and for the concatenation of an extra piece of information (the ebv scalar) before the linear layer.

## Code from https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py
## '''
## Properly implemented ResNet-s for CIFAR10 as described in paper [1].
## The implementation and structure of this file is hugely influenced by [2]
## which is implemented for ImageNet and doesn't have option A for identity.
## Moreover, most of the implementations on the web is copy-paste from
## torchvision's resnet and has wrong number of params.
## Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following
## number of layers and parameters:
## name      | layers | params
## ResNet20  |    20  | 0.27M
## ResNet32  |    32  | 0.46M
## ResNet44  |    44  | 0.66M
## ResNet56  |    56  | 0.85M
## ResNet110 |   110  |  1.7M
## ResNet1202|  1202  | 19.4m
## which this implementation indeed has.
## Reference:
## [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
##     Deep Residual Learning for Image Recognition. arXiv:1512.03385
## [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
## If you use this implementation in you work, please don't forget to mention the
## author, Yerlan Idelbayev.
## '''

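import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    # Kaiming (He) initialization for the linear and convolutional layers
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)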

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlockV2(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A', h=1.0, track_run_stat=True, debug=False):
        super(BasicBlockV2, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat) # true is the default
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat)
        self.h = h
        self.debug = debug

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes, track_running_stats=False)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.debug:
            print('last bn2: ',self.bn2.weight,', ',
                  self.bn2.bias,', ',
                  self.bn2.running_mean,', ',
                  self.bn2.running_var)

        out = self.shortcut(x) + self.h * out  # Zhang et al. h=1 for the default Resnet
        out = F.relu(out)
        return out


class ResNetV2(nn.Module):
    def __init__(self, block, num_blocks, num_input_channels = 5, num_classes=10, h=1.0):
        super(ResNetV2, self).__init__()
        self.in_planes = 16

        ## JEC 4/12/2019
        self.num_input_channels = num_input_channels
        self.n_bins = num_classes


##        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(self.num_input_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16,track_running_stats=True) # True is the default
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, h=h, track_run_stat=True)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, h=h, track_run_stat=True)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, h=h, track_run_stat=True, debug=False)
##        self.linear = nn.Linear(64, num_classes)
        ## JEC add reddening variable
        self.linear = nn.Linear(64+1, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride, h=1.0, track_run_stat=True, debug=False):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, h=h, track_run_stat=track_run_stat, debug=debug))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, reddening):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = torch.cat((out,reddening),dim=1)  ## add the reddening
        out = self.linear(out)
        return out

def resnet20(h=0.1):
    return ResNetV2(BasicBlockV2, [3, 3, 3], num_input_channels = 5, num_classes=180, h=h)

Your model seems to work fine and creates the same output after restoring:

# Setup
model = resnet20()
data = torch.randn(5, 5, 224, 224)
target = torch.randint(0, 128, (5,))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Train for 10 epochs
for idx in range(10):
    optimizer.zero_grad()
    output = model(data, torch.zeros(data.size(0), 1))
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(idx, loss.item()))

# Create reference eval output
model.eval()
data = torch.ones(1, 5, 224, 224)
with torch.no_grad():
    output = model(data, torch.zeros(data.size(0), 1))
torch.save({
    'state_dict': model.state_dict(),
    'ref_output': output
    },'ref.pth')


# in new script
model = resnet20()
checkpoint = torch.load('ref.pth')
model.load_state_dict(checkpoint['state_dict'])
ref_output = checkpoint['ref_output']
model.eval()

# Get output for ones
data = torch.ones(1, 5, 224, 224)
with torch.no_grad():
    output = model(data, torch.zeros(data.size(0), 1))

# Compare
print('max abs error: ', (output - ref_output).abs().max())
> max abs error:  tensor(0.)

I cannot check the other code, since some methods are undefined or point to specific data paths.

@ptrblck
My “resnet20” is supposed to be called as model(image, extra_scalar), with
images of shape N x (5, 64, 64) and extra_scalar of shape N x 1.
The outputs have shape N x 180 (i.e. a vector of length 180 per sample), to which one then applies F.softmax(out,dim=1) to get probabilities.
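
To make the shapes concrete, here is a minimal sketch in which random logits stand in for the model output (an assumption: the real call would be model(image, extra_scalar) on the trained network).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(10, 180)         # stand-in for the N x 180 model output
probs = F.softmax(logits, dim=1)      # normalize over the 180 bins of each sample
row_sums = probs.sum(dim=1)           # each row should sum to 1
predicted_bin = probs.argmax(dim=1)   # most probable redshift bin per sample
print(probs.shape, row_sums)
```

Note that F.cross_entropy takes the raw logits, not these probabilities, since it applies log-softmax internally.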

Well, I do not know how to give you some real data and the trained model: 100k samples is about 7.7G…

Hi @ptrblck

I have tried to repeat your exercise, which gave me a very good debugging idea (save the inputs/outputs of some reference samples together with the model dictionary). To do so I have written fresh code. During the training session, after the training optimisation, I switch to model.eval() and send 10 reference inputs to the model. Then I save the model parameters, the reference inputs and the corresponding outputs:

        test_ref_out = test_ref(args, model, device, sample_test_ref,
                        test_transforms, epoch)

        #save model state and reference test in/out
        state = {
            'model_state_dict': model.state_dict(),
            **test_ref_out
            }

with the returned dictionary test_ref_out defined as

{'ref_in': new_img_batch, 'ref_ebv': ebv_batch, 'ref_out': output}

Now, in the debugging session

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)

use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
print("Use device....: ",device)

# pin_memory speeds up the CPU->GPU transfer on NVIDIA GPUs
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}


#model
model = resnet20()
# in-place method to transfer to GPU
# according to https://pytorch.org/docs/stable/nn.html#torch.nn.Module.to
model.to(device)

#load model parameters
root_file = "./"
checkpoint_file ="dbg-resnet20-2.pth"

checkpoint_file = root_file + checkpoint_file
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['model_state_dict'])

#x-check model inference from reference samples
ref_input_img = checkpoint['ref_in']
ref_input_ebv = checkpoint['ref_ebv']
ref_output    = checkpoint['ref_out']
model.eval()
with torch.no_grad():
    output = model(ref_input_img,ref_input_ebv)
# Compare
diff = (output - ref_output).abs().max().item()
print('max abs error: ', diff)
if diff>0:
    print('ref_out\n',ref_output,'\n',F.softmax(ref_output,dim=1))
    print('new_out\n',output,'\n',F.softmax(output,dim=1))

#x-check if load the same data
test_ref_file  = 'test10k.npz'
# Reference test
test_ref_dataset = DatasetPz(test_ref_file, transform=None)
test_ref_loader = torch.utils.data.DataLoader(
    test_ref_dataset,
    batch_size=10, drop_last=True,
    shuffle=False, **kwargs)

print("test_ref_loader length=",len(test_ref_loader))
testref_loader_iterator = iter(test_ref_loader)
sample_test_ref = next(testref_loader_iterator) # once for all
img_ref = sample_test_ref['image'].transpose(3,1).transpose(3,2)  # input image

ebv_ref = sample_test_ref['ebv']    # input redenning
img_ref,ebv_ref = img_ref.to(device),ebv_ref.to(device)
# Compare
print('img: max abs error: ', (img_ref - ref_input_img).abs().max())
print('ebv: max abs error: ', (ebv_ref - ref_input_ebv).abs().max())

The result is not satisfactory: after reloading, the model output for the loaded reference inputs differs from the saved reference output, although I have cross-checked that the reference inputs correspond to the same data that I can load directly from file (2nd & 3rd comparisons).

Model output:

max abs error:  11.42324447631836

The reference values saved in the checkpoint file (I also print F.softmax(out,dim=1) to get the probability distribution):

ref_out
 tensor([[-1.2732, -0.4523,  0.3474,  ..., -2.2444, -1.8375, -1.7659],
        [11.1764, 12.5662, 11.4143,  ..., -2.0558, -0.6456, -2.7947],
        [ 3.6198,  4.7427,  4.8422,  ..., -1.8688, -1.0088, -1.7117],
        ...,
        [-1.1483, -1.3629, -1.4520,  ..., -2.2878, -2.3314, -2.1438],
        [ 1.8598,  2.7087,  2.9787,  ..., -1.2867, -0.7799, -0.9943],
        [-0.6644,  0.2698,  0.9862,  ..., -2.5374, -1.9054, -1.7161]],
       device='cuda:0') 
 tensor([[1.3692e-06, 3.1115e-06, 6.9227e-06,  ..., 5.1839e-07, 7.7875e-07,
         8.3651e-07],
        [2.8840e-02, 1.1577e-01, 3.6586e-02,  ..., 5.1681e-08, 2.1172e-07,
         2.4686e-08],
        [2.4761e-03, 7.6110e-03, 8.4070e-03,  ..., 1.0235e-05, 2.4187e-05,
         1.1977e-05],
        ...,
        [4.1218e-05, 3.3256e-05, 3.0422e-05,  ..., 1.3188e-05, 1.2626e-05,
         1.5231e-05],
        [1.9157e-03, 4.4774e-03, 5.8652e-03,  ..., 8.2377e-05, 1.3674e-04,
         1.1036e-04],
        [8.0213e-07, 2.0415e-06, 4.1790e-06,  ..., 1.2326e-07, 2.3189e-07,
         2.8021e-07]], device='cuda:0')

And now the output of the reloaded model:

new_out
 tensor([[ 0.1373,  1.2026,  1.6492,  ..., -0.6598,  0.0265, -1.3969],
        [ 0.0962,  1.1430,  1.6046,  ..., -0.6021,  0.0197, -1.3420],
        [ 0.1854,  1.1745,  1.7196,  ..., -0.5364,  0.0581, -1.4119],
        ...,
        [ 0.2378,  1.2321,  1.7412,  ..., -0.5621,  0.0525, -1.4340],
        [ 0.1403,  1.1963,  1.6627,  ..., -0.6335,  0.0163, -1.3910],
        [ 0.0681,  1.1639,  1.6145,  ..., -0.7172,  0.0148, -1.3750]],
       device='cuda:0') 
 tensor([[0.0021, 0.0060, 0.0093,  ..., 0.0009, 0.0018, 0.0004],
        [0.0021, 0.0059, 0.0093,  ..., 0.0010, 0.0019, 0.0005],
        [0.0023, 0.0062, 0.0107,  ..., 0.0011, 0.0020, 0.0005],
        ...,
        [0.0024, 0.0066, 0.0109,  ..., 0.0011, 0.0020, 0.0005],
        [0.0020, 0.0059, 0.0094,  ..., 0.0009, 0.0018, 0.0004],
        [0.0018, 0.0054, 0.0084,  ..., 0.0008, 0.0017, 0.0004]],
       device='cuda:0')

So, why does my reloaded model not give the same outputs as during the training phase, even though in both cases I have used model.eval()? What about the random generator?

I do not know if you have noticed that I systematically use at the beginning of my scripts

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)

Your workflow seems to match my dummy example, which returns exactly the same output.
However, in your case you still see a huge difference, and I’m not sure at the moment where it might come from.
Could you use torch.ones for all inputs instead of saving the data?
It should of course still work with the saved data, but there seems to be some mismatch somewhere in the code.

There is definitely something strange.
I have filled the reference samples with ones:

    img_test_ref = torch.ones(10,64,64,5) # NxHxWxC (this will be transformed to NCHW in the test_ref() function)
    ebv_test_ref = torch.ones(10,1)
    z_test_ref   = torch.ones(10,1)
    sample_test_ref = {'image':img_test_ref,'ebv':ebv_test_ref,'z':z_test_ref}

Then, after 2 epochs, I have reloaded the model and the above 1-filled tensors:

max abs error:  12.013355255126953

The reference outputs:

ref_out
 tensor([[-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501],
        [-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501],
        [-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501],
        ...,
        [-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501],
        [-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501],
        [-1.3784, -0.3093,  0.4179,  ..., -0.5395, -1.2264, -0.2501]],
       device='cuda:0') 
 tensor([[5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05],
        [5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05],
        [5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05],
        ...,
        [5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05],
        [5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05],
        [5.3705e-06, 1.5643e-05, 3.2369e-05,  ..., 1.2427e-05, 6.2522e-06,
         1.6597e-05]], device='cuda:0')

While the new output reads:

new_out
 tensor([[ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568],
        [ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568],
        [ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568],
        ...,
        [ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568],
        [ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568],
        [ 0.4656, -3.6460, -1.1858,  ..., -2.8810, -1.0660, -5.4568]],
       device='cuda:0') 
 tensor([[6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09],
        [6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09],
        [6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09],
        ...,
        [6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09],
        [6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09],
        [6.3085e-07, 1.0335e-08, 1.2099e-07,  ..., 2.2210e-08, 1.3639e-07,
         1.6899e-09]], device='cuda:0')

Dear @ptrblck
After the exercise with Resnet20 described above, I have run the same scripts with a purely convolutional network (no BN).

Then the result conforms to what we expect, that is to say, after reloading, the network gives the same outputs for the same inputs.

max abs error:  0.0

So,

  1. we can conclude that the debugging code is working
  2. in the case of the Resnet20 network we should investigate what can produce the difference we see after reloading the model.
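
Since the network without BatchNorm reproduces its outputs, one thing worth ruling out is the mode in which the BatchNorm layers see data before the checkpoint is written. The sketch below, on an isolated layer with made-up input data, only illustrates the general behaviour: running statistics are frozen in eval mode, but are updated by every forward pass in train mode, even inside torch.no_grad(). It does not by itself explain the discrepancy reported here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(16, 3, 8, 8) * 2.0 + 5.0   # data whose mean is far from 0

# eval mode: running statistics are read but never modified
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
frozen_in_eval = torch.equal(before, bn.running_mean)

# train mode: running statistics are updated by every forward pass,
# even inside torch.no_grad(), which only disables gradient tracking
bn.train()
with torch.no_grad():
    bn(x)
updated_in_train = not torch.equal(before, bn.running_mean)

print("frozen in eval mode  :", frozen_in_eval)
print("updated in train mode:", updated_in_train)
```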

Thanks and happy x-mas.

I have found something, but I really do not understand it :<)

  1. If I use at the beginning of training session (notice that the reference sample inputs are filled with ones):
torch.backends.cudnn.deterministic =True
torch.backends.cudnn.benchmark = False

and the same settings in the debugging session, then

max abs error:  0.0

while if the debugging session starts with

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

then

max abs error:  9.5367431640625e-07

  2. If now the training session is done with
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

the result of the debugging session is independent of the torch.backends.cudnn.deterministic = False/True and
torch.backends.cudnn.benchmark = True/False values, and it yields

max abs error:  4.76837158203125e-07

  3. I have also used the torch.backends.cudnn.enabled flag in the debugging session (the two other flags are commented out):
torch.backends.cudnn.enabled = True
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

then

max abs error:  0.0

While if

torch.backends.cudnn.enabled = False

then

max abs error:  2.384185791015625e-06

So, it is clear that the torch.backends.cudnn flags matter for recovering the results after model reloading, at least when the model has a ResNet architecture, maybe due to the BatchNorm layers. But then, does it mean that floating-point arithmetic precision is the origin of all my troubles with BN? What would then be the correct settings to use?
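
Differences of the order of 1e-6 to 1e-7 are comparable to float32 round-off for outputs of this magnitude, whereas the earlier errors of order 10 are not. A minimal sketch of how both sessions could share the same settings and compare outputs with a tolerance instead of exact equality follows. The helper names `make_reproducible` and `outputs_match` are hypothetical, and the tolerance `atol=1e-5` is an assumption to be adapted.

```python
import random
import torch


def make_reproducible(seed=0):
    """Seed the RNGs and ask cuDNN for deterministic kernels (hypothetical helper)."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)            # silently ignored without a GPU
    torch.backends.cudnn.deterministic = True   # prefer deterministic algorithms
    torch.backends.cudnn.benchmark = False      # no autotuning of convolution algorithms


def outputs_match(output, ref_output, atol=1e-5):
    """Compare two output tensors up to a small tolerance instead of exact equality."""
    return torch.allclose(output, ref_output, atol=atol)


make_reproducible(0)
ref = torch.randn(10, 180)
close = outputs_match(ref + 1e-6, ref)   # round-off sized difference
far = outputs_match(ref + 10.0, ref)     # difference of order 10, as in the failing runs
print(close, far)
```

Calling the same helper at the top of both the training script and the debugging script keeps the two sessions on identical flag values.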

Just in case you want more infos on my system:

PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: CentOS Linux release 7.7.1908 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-39)
CMake version: version 2.8.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration: GPU 0: Tesla V100-PCIE-32GB
Nvidia driver version: 418.87.01
cuDNN version: /opt/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.5.0

Versions of relevant libraries:
[pip] numpy==1.17.4
[pip] numpydoc==0.9.1
[pip] torch==1.3.1
[pip] torchvision==0.4.2
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] pytorch 1.3.1 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchvision 0.4.2 py37_cu101 pytorch

cudnn.benchmark might yield non-deterministic results, and I’ve apparently missed these calls in your original script.
These inaccuracies can accumulate depending on the model depth and architecture.
One way to narrow down this issue would be to compare each layer output using the deterministic setup and your current one.
It would be interesting to see if there is a sudden increase in error, as the difference in the final outputs is large.

Thanks @ptrblck may I ask you to give me more inputs to conduct this test? Best.

You could use forward hooks as described here.
However, this might be a bit tedious, so if you could wait for one or two days, I can debug it and see what’s going on.
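
A minimal sketch of the forward-hook approach, assuming a tiny stand-in network in place of resnet20() and copying the state dict in memory in place of a real checkpoint reload; `record_outputs` and `build` are hypothetical helper names.

```python
import torch
import torch.nn as nn


def record_outputs(model, *inputs):
    """Run one forward pass and return {module_name: output} for every leaf module."""
    recorded, handles = {}, []

    def make_hook(name):
        def hook(module, hook_inputs, output):
            recorded[name] = output.detach().clone()
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:      # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))

    model.eval()
    with torch.no_grad():
        model(*inputs)
    for handle in handles:
        handle.remove()                             # leave the model unchanged
    return recorded


def build():
    # Tiny stand-in network (assumption); resnet20() would be used in practice
    return nn.Sequential(nn.Conv2d(5, 8, 3, padding=1), nn.BatchNorm2d(8),
                         nn.ReLU(), nn.Conv2d(8, 4, 3, padding=1))


torch.manual_seed(0)
reference = build()
reloaded = build()
reloaded.load_state_dict(reference.state_dict())    # mimics a checkpoint reload

x = torch.ones(2, 5, 16, 16)
ref_out = record_outputs(reference, x)
new_out = record_outputs(reloaded, x)

# Per-layer maximum absolute difference, in forward order
errors = {name: (ref_out[name] - new_out[name]).abs().max().item() for name in ref_out}
for name, err in errors.items():
    print(f"{name:>4s}  max abs error = {err:.3e}")
```

The first layer at which the error jumps is where the two sessions start to diverge; saving `ref_out` next to the checkpoint would allow the same comparison across sessions.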

Btw. thanks for the debugging so far and the code snippets! :)

You’re welcome @ptrblck, I will wait as I am not experienced in PyTorch.
By the way, if by any chance my problem turns out to be valuable, that would be great. My original problem was to understand this one: “Resnet: problem with test loss”. And during the debugging this new problem arose… That’s the life of research :)
“See you soon”.

I’ve tried to reproduce the large difference in your outputs, but failed to do so.
I’m currently using your resnet20 implementation with two inputs in the forward pass and fixed inputs.

To further debug, could you post a full code to reproduce this issue, please?

Hello @ptrblck, do you mean the code (let us say code 1) which shows the dependence on cudnn.benchmark, or the one (code 2) which exhibits the erratic behaviour of the test loss compared to the train loss (this is the original plot at the top of this thread)?
Thanks
PS: Meanwhile, I will rerun the code 1

I have the same problem and, following the advice in this thread, I ran the following experiments.

  • Using a different dataset: I'm using MNIST with Zerospadding(114), which gives images of size (256,256). After reloading and running prediction, the model returns a high accuracy, so I think the reloading works fine on the MNIST dataset.
  • Training for a while and reloading in the same session: this gives a high accuracy, the same as the model before reloading. But in a different session it doesn't work anymore.

So is it a problem with my dataset? That would be really complex to track down. I am trying to check every layer's output, but I don't know how to do it.
Here is a simplified description of my script:

#-----------------------fix random seed---------------------------------#
import random

import numpy as np
import torch

args.seed = 5153
print("Random Seed: ", args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.gpus:
    # Sets the seed for generating random numbers on all GPUs.
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

#---------------------------training-------------------------------------#
model = MobileNet1()
model.load_state_dict(weight['state_dict'])
data = generator(data_path) # generator is a python generator
for epoch in range(epochs):
    for x, y in data:
        output = model(x)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss /= accumulate_step
        loss.backward()
        optimizer.step()
    scheduler.step()
    save_checkpoint(filepath=args.save,
                    filename='{}-epoch{}-val_loss{:.4f}.pth'.format(
                        args.model_name, epoch, val_loss),
                    state={'epoch': epoch,
                           'state_dict': model.state_dict(),
                           'best_prec1': best_test,
                           'optimizer': optimizer.state_dict()})

I am using MobileNetV1; here is my model definition:

import torch.nn as nn
import torch.nn.functional as F

class MobileNet1(nn.Module):
    def __init__(self,initial_channel,n_class):
        super(MobileNet1, self).__init__()
        self.class_num = n_class

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(initial_channel, 32, 2),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(8),
        )
        self.fc = nn.Linear(1024, self.class_num)
    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

Merry Christmas everyone!
Over the past weeks I struggled with a similar problem and invested a lot of nerves and computation time in it. Things that I found out and that might help others:

  • If you calculate a batch-wise mean loss (e.g. MSE), be careful how you calculate the mean over all batches. If your batches are not all the same size, you need to weight each batch mean by its batch size (or use drop_last in the DataLoader so that the smaller last batch is dropped).
  • If you're in eval() mode, be careful if you print any summaries with external libraries like torch-summary or torchsummary. There is / was a bug concerning the order of the summary and eval() statements; I believe it's fixed now in torch-summary.
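
To illustrate the first point, here is a tiny sketch with made-up numbers (three batches, the last one smaller). The function name dataset_mean_loss is hypothetical.

```python
def dataset_mean_loss(batch_means, batch_sizes):
    """Mean loss over the whole dataset from per-batch mean losses."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Made-up example: per-batch mean losses and the corresponding batch sizes.
batch_means = [1.0, 1.0, 4.0]
batch_sizes = [4, 4, 2]

naive = sum(batch_means) / len(batch_means)             # treats all batches equally
weighted = dataset_mean_loss(batch_means, batch_sizes)  # (4 + 4 + 8) / 10
print(naive, weighted)
```

The naive average over-counts the small last batch, so the two numbers differ whenever the batch sizes are unequal.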

(I worked with Resnet18 and ResNeXt50)

If you are coming across this thread late (like I did) because you have a high loss after loading a saved model, this might be the solution. TL;DR: save and load the optimizer state as well.
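
A minimal sketch of what that looks like, assuming a toy nn.Linear model, SGD with momentum, and an in-memory buffer standing in for the checkpoint file (all of these are placeholders, not the setup from the posts above):

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one training step so the optimizer holds momentum buffers.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save both state dicts; the buffer stands in for a path such as "model.pth".
buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Later (e.g. in a new session): rebuild the objects, then restore both states.
checkpoint = torch.load(buffer)
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```

Without the second load_state_dict call, the optimizer would restart with empty momentum buffers, which can cause a jump in the loss when training resumes.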