Pathological loss values when the model is reloaded

Well, let's try this: here is the debugging script, which shows the pathological behavior (see the plots of the BatchNorm parameters after model reload, posted on 22 Dec. 2019; the title is in fact inappropriate, as BatchNorm does not seem to be responsible).

import torch.optim as optim
import torch.nn.functional as F
import random
import types
import os

# the Network Model Classes
from pz_network import *
# the Utility Classes to manipulate the images
from pz_image_manip_utils import *


############## Process BATCH ################
def process(model, sample_batched, transforms, train=None, debug=False,
            z_min=0.0, z_max=1.0):
    """Evaluate the cross-entropy loss of ``model`` on one batch, without gradients.

    The model is always switched to ``eval()`` mode (a reloaded checkpoint is
    only being inspected, not trained), so its BatchNorm layers normalise with
    their stored running statistics instead of the batch statistics.

    Parameters
    ----------
    model : torch.nn.Module
        Network called as ``model(images, ebv)``; must expose ``n_bins``, the
        number of redshift classes (size of its last layer).
    sample_batched : dict
        DataLoader batch with keys ``'image'`` (N, H, W, C), ``'ebv'`` and ``'z'``.
    transforms : sequence of callables
        Applied in order to each image given as a NumPy H W C array; the last
        one must return a C H W tensor (e.g. ``ToTensorPz``).
    train : ignored
        Kept for call-site compatibility; evaluation is always in eval mode.
    debug : ignored
        Kept for call-site compatibility; the diagnostic prints always run.
    z_min, z_max : float
        Redshift range mapped linearly onto the ``n_bins`` classes; redshifts
        outside the range fall into the first/last bin.

    Returns
    -------
    float
        Mean cross-entropy loss over the batch.

    Raises
    ------
    ValueError
        If ``z_max <= z_min``.
    """
    if z_max <= z_min:
        raise ValueError(f"z_max ({z_max}) must be greater than z_min ({z_min})")

    # The checkpoint is only evaluated here: no training.
    model.eval()

    n_bins = model.n_bins  # number of neurons of the last layer = number of bins

    # Run on the device the model lives on, instead of relying on a
    # module-level ``device`` global defined elsewhere in the script.
    device = _model_device(model)

    # turn off gradient computation for all tensors
    with torch.no_grad():
        img_batch = sample_batched['image']   # input images, N H W C
        ebv_batch = sample_batched['ebv']     # input reddening
        z_batch = sample_batched['z']         # target redshift

        # Apply the per-image transforms (H W C numpy -> C H W tensor) and
        # collect the results into an N C H W buffer.
        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)
        for i in range(len(img_batch)):
            img = img_batch[i].numpy()
            for transform in transforms:
                img = transform(img)
            new_img_batch[i] = img

        # Redshift -> bin index for CrossEntropyLoss (class index, no one-hot):
        # normalise to [0, 1], scale by n_bins, truncate toward zero (like
        # int()), then clamp into {0, ..., n_bins - 1}.
        z_norm = (z_batch - z_min) / (z_max - z_min)
        new_z_batch = (z_norm * n_bins).long().clamp(0, n_bins - 1)

        # send the inputs and target to the model's device
        new_img_batch = new_img_batch.to(device)
        ebv_batch = ebv_batch.to(device)
        new_z_batch = new_z_batch.to(device)

        # Feedforward
        output = model(new_img_batch, ebv_batch)
        print("ebv_batch: ",ebv_batch[:5])
        print("new_img_batch:\n",new_img_batch[:5])
        print("output:\n",output[:5])

        # the loss
        loss = F.cross_entropy(output, new_z_batch)

    return loss.item()


def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


############## MAIN ################

use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
print("Use device....: ",device)

# pin_memory speeds up the CPU->GPU transfer on NVIDIA GPUs
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

# Model, moved in place to the selected device
# (https://pytorch.org/docs/stable/nn.html#torch.nn.Module.to)
model = resnet20()
model.to(device)

# Load the model parameters.
root_file = "/sps/lsst/users/campagne/torchphotoz/"
checkpoint_file = "robust-state-50_0_resnet20.pth.sav"
checkpoint_file = os.path.join(root_file, checkpoint_file)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# loaded on a CPU-only node) onto the device selected above.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


def _print_tensor_stats(name, tensor):
    """Print the name, shape and min/max/mean/std of one parameter or buffer."""
    # float() so integer buffers (e.g. BatchNorm's num_batches_tracked) support mean/std
    values = tensor.detach().float()
    print('name: ', name)
    print('param.shape: ', tensor.shape)
    print(f"min/max/mean/std: {values.min().item():0.2e},{values.max().item():0.2e},"
          f"{values.mean().item():0.2e},{values.std().item():0.2e}")
    print('=====')


print("Just after model.load_state_dict:")
for name, param in model.named_parameters():
    _print_tensor_stats(name, param)
# In eval() mode BatchNorm normalises with running_mean/running_var, which are
# buffers rather than parameters, so they are inspected separately.
print("Buffers just after model.load_state_dict:")
for name, buf in model.named_buffers():
    _print_tensor_stats(name, buf)


# train batch
train_file = '/sps/lsst/data/campagne/sdss/train100k.npz'

# test batch
test_file  = '/sps/lsst/data/campagne/sdss/test100k.npz'

train_batch_size = 128
test_batch_size  = 128

# Train data
train_dataset = DatasetPz(train_file, transform=None)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size, drop_last=True,
    shuffle=False, **kwargs)

trainloader_iterator = iter(train_loader)
print("train_loader length=",len(train_loader), " WARNING: NO SHUFFLE")

# Test data
test_dataset = DatasetPz(test_file, transform=None)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size, drop_last=True,
    shuffle=False, **kwargs)

testloader_iterator = iter(test_loader)
print("test_loader length=",len(test_loader))

# Fix all seeds once (seed 0, JEC 2/11/19)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)

# Training-time augmentation, disabled here so that train batches go through
# exactly the same transforms as test batches:
# train_transforms = [RandomApplyPz([[flipH,flipV,identity],
#                                    [rot90,rot180,rot270,identity]]),ToTensorPz()]

no_train_transforms = [ToTensorPz()]

test_transforms = [ToTensorPz()]

max_num_batch = 10  # evaluate at most this many batches of each loader

train_loss = 0
train_num_batch = min(max_num_batch, len(train_loader))
for i in range(train_num_batch):
    print("train[",str(i),"]: no data augm")
    sample_batched = next(trainloader_iterator)
    train_loss += process(model,sample_batched,no_train_transforms,train=True)
# guard against an empty loader (e.g. fewer samples than one batch with drop_last)
train_mean_loss = train_loss / train_num_batch if train_num_batch else float('nan')
print('Train mean loss over ',train_num_batch,' batches = ',train_mean_loss)

test_loss = 0
test_num_batch = min(max_num_batch, len(test_loader))
for i in range(test_num_batch):
    print("test [",str(i),"]")
    sample_batched = next(testloader_iterator)
    test_loss += process(model,sample_batched,test_transforms,train=False)
test_mean_loss = test_loss / test_num_batch if test_num_batch else float('nan')
print('Test mean loss over ',test_num_batch,' batches = ',test_mean_loss)

Here is the code to load the data:

class DatasetPz(Dataset):
    """Photometric-redshift dataset read with ``np.load``.

    The loaded object must be indexable by the field names:
        'z'    : true redshift of each object
        'ebv'  : reddening of each object
        'data' : image tensors laid out N H W C, C = number of channels
                 (e.g. 5 filters)
    Every field is converted to float32 once, at construction time.
    """

    def __init__(self, file_path, transform=None):
        # The loaded archive stays reachable as ``self.data``.
        self.data = np.load(file_path)

        # Materialise the three fields as float32 arrays.
        self.z, self.ebv, self.imgs = (
            self.data[field].astype("float32") for field in ("z", "ebv", "data")
        )

        self.z_min, self.z_max = np.min(self.z), np.max(self.z)
        print("DatasetPz zmin: ", self.z_min, " zmax: ", self.z_max)

        self.transform = transform

    def __len__(self):
        """Number of objects (images) in the dataset."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict with keys 'image', 'z' and 'ebv'.

        When a ``transform`` was given, it is applied to the image only.
        """
        raw_image = self.imgs[index]
        image = raw_image if self.transform is None else self.transform(raw_image)
        return dict(image=image, z=self.z[index], ebv=self.ebv[index])

class ToTensorPz:
    """Convert an H x W x C NumPy image (TensorFlow layout) into a
    C x H x W ``torch.Tensor`` (PyTorch layout)."""

    def __call__(self, pic):
        chw_view = pic.transpose((2, 0, 1))
        # Copy into a fresh buffer before wrapping: the original author notes a
        # crash without it (presumably arrays with negative strides, e.g. from
        # flip augmentations, which torch.from_numpy rejects).
        return torch.from_numpy(chw_view.copy())

    def __repr__(self):
        return f"{type(self).__name__}()"

And here is the network, which I adapted to take a configurable number of input channels in the first convolution and to concatenate an extra piece of information (the scalar ebv) before the linear layer.

## Code from https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py
## '''
## Properly implemented ResNet-s for CIFAR10 as described in paper [1].
## The implementation and structure of this file is hugely influenced by [2]
## which is implemented for ImageNet and doesn't have option A for identity.
## Moreover, most of the implementations on the web are copy-pasted from
## torchvision's resnet and have the wrong number of params.
## Proper ResNet-s for CIFAR10 (for fair comparison etc.) have the following
## number of layers and parameters:
## name      | layers | params
## ResNet20  |    20  | 0.27M
## ResNet32  |    32  | 0.46M
## ResNet44  |    44  | 0.66M
## ResNet56  |    56  | 0.85M
## ResNet110 |   110  |  1.7M
## ResNet1202|  1202  | 19.4m
## which this implementation indeed has.
## Reference:
## [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
##     Deep Residual Learning for Image Recognition. arXiv:1512.03385
## [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
## If you use this implementation in your work, please don't forget to mention the
## author, Yerlan Idelbayev.
## '''

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable so it can sit inside a module tree.

    The callable is stored as a plain attribute (not a parameter, buffer or
    sub-module), so the layer has no parameters of its own.
    NOTE(review): lambdas are not picklable, so pickling a whole model that
    contains this layer (rather than its state_dict) will presumably fail.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlockV2(nn.Module):
    """Residual basic block with a scaled residual branch.

    Computes ``relu(shortcut(x) + h * R(x))`` where ``R`` is
    conv3x3 -> BN -> ReLU -> conv3x3 -> BN; ``h = 1`` is the standard ResNet block.

    Parameters
    ----------
    in_planes, planes : int
        Input and output channel counts.
    stride : int
        Stride of the first convolution (spatial down-sampling factor).
    option : {'A', 'B'}
        Shortcut used when the shape changes (``stride != 1`` or
        ``in_planes != planes``): 'A' = parameter-free spatial sub-sampling
        plus zero channel padding, 'B' = 1x1 convolution + BatchNorm projection.
    h : float
        Scale applied to the residual branch.
    track_run_stat : bool
        Forwarded as ``track_running_stats`` to ``bn1`` and ``bn2``.
    debug : bool
        If true, print ``bn2``'s weight, bias and running statistics on every
        forward call.

    Raises
    ------
    ValueError
        If a shortcut is needed and ``option`` is neither 'A' nor 'B'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A', h=1.0, track_run_stat=True, debug=False):
        super(BasicBlockV2, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat) # True is the default
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat)
        self.h = h
        self.debug = debug

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A: sub-sample by the
                # block's stride and zero-pad the missing channels, split
                # evenly before/after. For planes == 2 * in_planes and
                # stride == 2 this is exactly the usual ``planes//4`` padding on
                # both sides; it also stays shape-correct for other configurations.
                pad_front = (planes - in_planes) // 2
                pad_back = planes - in_planes - pad_front
                self.shortcut = LambdaLayer(lambda x: F.pad(
                    x[:, :, ::stride, ::stride],
                    (0, 0, 0, 0, pad_front, pad_back), "constant", 0))
            elif option == 'B':
                # NOTE(review): unlike bn1/bn2 this BatchNorm ignores
                # ``track_run_stat``: with track_running_stats=False it keeps
                # no running statistics and normalises with the current batch
                # statistics even in eval() mode. Left unchanged so existing
                # checkpoints keep the same state_dict keys.
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes, track_running_stats=False)
                )
            else:
                # An identity shortcut cannot match the changed shape, so fail
                # here instead of at the first forward call.
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.debug:
            print('last bn2: ',self.bn2.weight,', ',
                  self.bn2.bias,', ',
                  self.bn2.running_mean,', ',
                  self.bn2.running_var)

        out = self.shortcut(x) + self.h * out  # Zhang et al.: h=1 for the default ResNet
        return F.relu(out)


class ResNetV2(nn.Module):
    """CIFAR-style ResNet adapted for photo-z classification.

    Compared with the CIFAR original, the first convolution takes
    ``num_input_channels`` channels, and a reddening value is concatenated
    to the pooled features before the final linear layer.

    Parameters
    ----------
    block : type
        Residual block class (e.g. ``BasicBlockV2``); must accept
        ``(in_planes, planes, stride, h=..., track_run_stat=..., debug=...)``
        and define ``expansion``.
    num_blocks : sequence of 3 ints
        Number of blocks in each of the three stages (16, 32, 64 channels).
    num_input_channels : int
        Number of channels of the input images (e.g. 5 filters).
    num_classes : int
        Number of output classes (redshift bins); also exposed as ``n_bins``.
    h : float
        Residual-branch scale forwarded to every block.
    """

    def __init__(self, block, num_blocks, num_input_channels = 5, num_classes=10, h=1.0):
        super(ResNetV2, self).__init__()
        self.in_planes = 16

        ## JEC 4/12/2019
        self.num_input_channels = num_input_channels
        self.n_bins = num_classes

        # CIFAR original: nn.Conv2d(3, 16, ...); here the channel count is configurable.
        self.conv1 = nn.Conv2d(self.num_input_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16,track_running_stats=True) # True is the default
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, h=h, track_run_stat=True)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, h=h, track_run_stat=True)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, h=h, track_run_stat=True, debug=False)
        ## JEC: one extra input feature for the reddening (CIFAR original: nn.Linear(64, num_classes))
        self.linear = nn.Linear(64+1, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride, h=1.0, track_run_stat=True, debug=False):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: updates ``self.in_planes`` to the stage's output width so
        the next stage starts from the right channel count.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, h=h, track_run_stat=track_run_stat, debug=debug))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, reddening):
        """Return class logits of shape (N, num_classes).

        ``reddening`` may be (N, 1) or a flat (N,) vector; it is appended as
        one extra feature after global average pooling.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling over the whole (H, W) map: identical to a
        # square kernel of size W for square maps, and also valid for
        # non-square ones.
        out = F.avg_pool2d(out, tuple(out.shape[2:]))
        out = out.view(out.size(0), -1)
        if reddening.dim() == 1:
            # accept a flat (N,) reddening vector as well as (N, 1)
            reddening = reddening.unsqueeze(1)
        out = torch.cat((out,reddening),dim=1)  ## add the reddening
        return self.linear(out)

def resnet20(h=0.1, num_input_channels=5, num_classes=180):
    """Build the 20-layer photo-z ResNet: three stages of 3 ``BasicBlockV2`` blocks.

    The defaults reproduce the original configuration: 5 input channels,
    180 redshift bins and residual-branch scale ``h = 0.1``.
    """
    return ResNetV2(BasicBlockV2, [3, 3, 3], num_input_channels=num_input_channels,
                    num_classes=num_classes, h=h)