Resnet: problem with test loss

Dear guru,

I’m working with PyTorch 1.3.0.

Here is a typical plot of the train/test loss behaviour as the number of epochs increases:
[plot: train-test-loss-resnet]

I’m not an expert but I have read several topics on similar problems. Well, let me explain what I’m doing.

First, I have used the implementation given by https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for resnet18 & resnet50, and the one from https://github.com/akamaster/pytorch_resnet_cifar10 for resnet32 and resnet56. For all these nets I got the same kind of erratic test-loss behaviour.

Second, my inputs are 5x64x64 images, so I have adapted the first convolutional layer, and the output of the last fully connected layer consists of 180 neurons. I have used batch sizes of 64, 128 and 256 for training, and 128 for testing: the same behaviour persists. I have also used either 300k or 100k training images (100k for the test): the same behaviour persists too. The images are not “standard” RGB photos: first, as you have probably already noticed, there are 5 channels; second, the pixel values can be negative (e.g. spanning the range (-0.01, 500)).

Third, I am aware of the model.train() statement for the training phase, as well as the model.eval() statement (together with a with torch.no_grad(): block) for the testing phase. Clearly, if I do not use model.eval() during the test phase, the test loss decreases gently, like the training loss. But this is not allowed, is it?

I have tried several things after reading posts about Batch Norm behaviour, without any success:

  • I have used SGD, Adam (& SWATS)
  • I have tried learning rates from lr = 0.1 down to lr = 1e-5
  • I have modified the BN momentum (default = 0.1): 0.5 and 0.01

Now, I have managed to get nice results (i.e. good training & testing losses) with a classical CNN (i.e. without any batch normalization or shortcuts), but I would like to study ResNet behaviour under adversarial attacks. So I would like to get a ResNet to fit my images :)

Any idea ?
Thanks

Among the different tests I have done, one seems silly but may reveal something strange. As the test/validation set I have used the same samples as for training, except that I do not flip them horizontally or vertically nor rotate them by 90, 180 or 270 deg. Here are the test & train losses:
[plot: test-train-loss-with-same-set]

Besides disabling data augmentation, do you still have the batchnorm layers in your model and call model.eval()?
Which loss curve do you get, when you pass the training set (with data augmentation) to the model after calling model.eval()?
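A minimal sketch of that check: the average loss over a whole loader with the model in eval() mode and gradients disabled. To keep it self-contained, the loader here yields (input, target) pairs and the model takes a single input; in this thread the loaders actually return dicts and the model also takes the reddening, so that part is an assumption to adapt. The toy model and random data are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def evaluate_loss(model, loader, device="cpu"):
    """Mean cross-entropy over all batches of `loader`, with the model in eval() mode.

    In eval() mode BatchNorm normalizes with its running estimates;
    torch.no_grad() only disables gradient tracking.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total += F.cross_entropy(model(x), y).item()
            n_batches += 1
    return total / max(n_batches, 1)


# Toy usage: a tiny conv net with BatchNorm on random 5-channel data.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(5, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 180),
)
data = TensorDataset(torch.randn(64, 5, 16, 16), torch.randint(0, 180, (64,)))
loader = DataLoader(data, batch_size=16)
train_set_eval_loss = evaluate_loss(model, loader)
print(train_set_eval_loss)
```

Passing the (augmented) training loader to this function gives the "training set in eval() mode" curve to compare with the usual training loss.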

Dear @ptrblck, thanks for your interest. I wonder whether the test loss behaviour comes from the BN or whether it is a problem with ResNet models for my images…

During training I use model.train(), during testing I use model.eval(), and the model is not changed between the two phases, which I perform at each epoch (same batch size = 128 for train & test).

When I completely switch OFF the transformations of the train set and use the same set for train & test, I get the same behaviour:
[plot: train/test losses]

And finally, if I switch off the shuffling and random transforms (flips & rotations) of the train set, and use the same set for test, then I get:
[plot: train/test losses]

Seems that the test loss is converging towards a value, but different from the train loss.

If you can isolate the different loss curves using the same data augmentation and the same data, I would think the difference might come from skewed batchnorm running estimates.
Could you set track_running_stats=False for all batchnorm layers and rerun the experiments?
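One way to run that experiment on an existing model is sketched below: every `nn.BatchNorm2d` is rebuilt with `track_running_stats=False`, keeping its learned weight and bias. The layers are rebuilt rather than just flipping the attribute because, depending on the PyTorch version, a layer that still owns running buffers may keep using them in eval() mode. The helper name `disable_bn_running_stats` is hypothetical, and the toy model stands in for the real ResNet.

```python
import torch
import torch.nn as nn


def disable_bn_running_stats(module):
    """Recursively replace each nn.BatchNorm2d in `module` by an equivalent layer
    built with track_running_stats=False, keeping its learned weight and bias.
    Such layers always normalize with the statistics of the current batch."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(child.num_features, eps=child.eps,
                                    momentum=child.momentum, affine=child.affine,
                                    track_running_stats=False)
            if child.affine:
                with torch.no_grad():
                    new_bn.weight.copy_(child.weight)
                    new_bn.bias.copy_(child.bias)
                new_bn = new_bn.to(child.weight.device)
            setattr(module, name, new_bn)  # swap the submodule in place
        else:
            disable_bn_running_stats(child)  # recurse into blocks / Sequentials
    return module


# Toy usage: in eval() mode the converted layers still use batch statistics,
# so train() and eval() give the same output for the same batch.
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(5, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
disable_bn_running_stats(model)
x = torch.randn(8, 5, 32, 32)
with torch.no_grad():
    out_eval = model.eval()(x)
    out_train = model.train()(x)
print(torch.allclose(out_eval, out_train))  # True
```

If the conversion is applied to a model that already has an optimizer, the optimizer should be recreated afterwards, since the new layers own new Parameter objects.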

Dear @ptrblck,
Below is a series of experiments with resnet20, batch_size=128 for both training and testing.

First, let's consider: same data for train and test, no data augmentation (i.e. no random H/V flips or 90/180/270 rotations), and BN track_running_stats=False.
[plot: train/test losses]

Here the two losses are pretty much the same after 3 epochs. OK, now I will turn train shuffling ON.

Here is the case where train set shuffling is ON, but still no data augmentation, test set = train set, and still BN track_running_stats=False:
[plot: train/test losses]

Now I will allow data augmentation for training

Now:
(1) train = test set
(2) with data augmentation + shuffling for Train set,
(3) still no shuffling and no data augmentation for test set
(4) BN with track_running_stats=False

[plot: train/test losses]

I will now relax the test set to be different from the train set, although both originate from the same big set of images.

Last experiment:
(1) test set DIFFERENT from train set but originating both from a large data set
(2) with data augmentation + shuffling for Train set,
(3) still no shuffling and no data augmentation for test set
(4) BN with track_running_stats=False

[plot: train/test losses]

So, if I am correct, the fact that the test loss follows the train loss is due to track_running_stats=False for the Batch Norm layers (at least when it is set for all of them). Maybe some BN layers are more concerned than others? Anyway, is it correct to have this switch turned to False???

Now, starting from the last experiment, I set track_running_stats=True for the first Batch Norm layer (bn1); the other BN layers keep track_running_stats=False:

self.bn1 = nn.BatchNorm2d(16, track_running_stats=True)

def forward(self, x, reddening):
    out = F.relu(self.bn1(self.conv1(x)))

[plot: train/test losses]

I start to see a difference between the test loss and the train loss… more to come.

Here is a new experiment, following the last one:

out = F.relu(self.bn1(self.conv1(x)))    # track_running_stats = True
out = self.layer1(out)                   # track_running_stats = True
out = self.layer2(out)                   # track_running_stats = False
out = self.layer3(out)                   # track_running_stats = False

[plot: train/test losses]

Now,

out = F.relu(self.bn1(self.conv1(x)))    # track_running_stats = True
out = self.layer1(out)                   # track_running_stats = True
out = self.layer2(out)                   # track_running_stats = True
out = self.layer3(out)                   # track_running_stats = False

[plot: train/test losses]

So now the last experiment… soon. Stay tuned :)

And the culprit is …

out = F.relu(self.bn1(self.conv1(x)))    # track_running_stats = True
out = self.layer1(out)                   # track_running_stats = True
out = self.layer2(out)                   # track_running_stats = True
out = self.layer3(out)                   # track_running_stats = True
out = F.avg_pool2d(out, out.size()[3])
out = out.view(out.size(0), -1)
out = torch.cat((out, reddening), dim=1)  # add the reddening
out = self.linear(out)

[plot: train/test losses]

So the culprit is the tracking of running stats in the last layer of the ResNet (strictly speaking, the last stage of BasicBlocks). The layers are defined as:

    self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, track_run_stat=True)
    self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, track_run_stat=True)
    self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, track_run_stat=True)

and the first BN

self.bn1 = nn.BatchNorm2d(16,track_running_stats=True)

So far so good. What should I conclude, Dr. Watson?

Thanks for the experiments.
It still looks like the running estimates are off.
Are you shifting or scaling the data in your augmentation pipeline somehow?
Usually you see bad running estimates, if the training data comes from another domain or was preprocessed in another way (e.g. normalized during training, raw data during validation).
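A toy illustration of that failure mode, with a single `nn.BatchNorm2d` standing in for a real network: the running estimates are accumulated on standardized inputs, then the layer is evaluated in eval() mode on inputs with the same preprocessing and on inputs with a different offset and scale. The numbers (offset 50, scale 20) are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(5)

# "Training": running estimates are accumulated from standardized inputs.
bn.train()
with torch.no_grad():
    for _ in range(100):
        bn(torch.randn(64, 5, 8, 8))

bn.eval()  # from now on the running estimates are used
with torch.no_grad():
    same_domain = bn(torch.randn(64, 5, 8, 8))            # preprocessed like the training data
    other_domain = bn(50 + 20 * torch.randn(64, 5, 8, 8))  # e.g. raw, unnormalized inputs

print(same_domain.mean().item(), same_domain.std().item())    # close to 0 and 1
print(other_domain.mean().item(), other_domain.std().item())  # roughly 50 and 20
```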

I would not consider setting track_running_stats=False the best workaround, as one might argue that you are leaking information from the test set, and the results would also depend on the batch size (the bigger the batch during testing, the better the estimate).
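The batch dependence can be seen directly on a single toy layer: with `track_running_stats=False`, the eval() output for one fixed sample changes with whatever else is in the batch, whereas a layer using running estimates gives the same output for that sample regardless of the batch. Sizes and values below are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4, track_running_stats=False)
bn.eval()  # eval mode, but there are no running estimates, so batch statistics are used

sample = torch.randn(1, 4, 8, 8)
batch_a = torch.cat([sample, torch.randn(7, 4, 8, 8)])            # 8 samples
batch_b = torch.cat([sample, 3 * torch.randn(127, 4, 8, 8) + 1])  # 128 samples, other statistics

with torch.no_grad():
    out_a = bn(batch_a)[0]
    out_b = bn(batch_b)[0]
print(torch.allclose(out_a, out_b))  # False: same sample, different output

# Default layer in eval() mode: normalization uses fixed running estimates.
bn_tracked = nn.BatchNorm2d(4).eval()
with torch.no_grad():
    same = torch.allclose(bn_tracked(batch_a)[0], bn_tracked(batch_b)[0])
print(same)  # True
```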

Dear @ptrblck,
For data augmentation I only use horizontal/vertical flips and rotations by 90, 180 and 270 deg, that is to say no pixel value scaling or shifting. The batch_size is 128 for both test and train phases.

The test and train sets (100k each) originate from the same larger data set (600k), and I have checked the statistics of both sets.
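A sketch of such a check on N x H x W x C arrays like the ones stored under the `data` key of the .npz files: per-channel mean and standard deviation for each set, plus a check that a flip followed by a 90 deg rotation leaves those statistics unchanged, since such transforms only permute pixels within each channel. The random arrays are stand-ins for the real data (assumption).

```python
import numpy as np


def channel_stats(imgs):
    """Per-channel mean and std of an N x H x W x C array."""
    imgs = np.asarray(imgs, dtype=np.float64)
    return imgs.mean(axis=(0, 1, 2)), imgs.std(axis=(0, 1, 2))


rng = np.random.default_rng(0)
# Stand-ins for the real arrays, e.g. np.load(path)["data"] for each set.
train_imgs = rng.normal(0.5, 2.0, size=(200, 64, 64, 5)).astype("float32")
test_imgs = rng.normal(0.5, 2.0, size=(200, 64, 64, 5)).astype("float32")

train_mean, train_std = channel_stats(train_imgs)
test_mean, test_std = channel_stats(test_imgs)
print(np.round(train_mean - test_mean, 3), np.round(train_std / test_std, 3))

# Flips and 90-degree rotations only permute pixels inside each channel,
# so they cannot shift or scale the per-channel statistics.
augmented = np.stack([np.rot90(np.flip(img, 0), 1) for img in train_imgs])
aug_mean, aug_std = channel_stats(augmented)
print(np.allclose(aug_mean, train_mean), np.allclose(aug_std, train_std))  # True True
```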

I do not know what I can do now. The problem comes from the last layer of the resnet model…
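One more diagnostic, which is a swapped-in technique and not something suggested in this thread: after training, reset the BatchNorm running estimates and recompute them by forwarding training batches in train() mode under `torch.no_grad()`, with `momentum=None` so that PyTorch keeps a cumulative (equal-weight) average instead of an exponential one. If the eval() test loss moves towards the training loss afterwards, the stored estimates were the problem. The helper name `recalibrate_bn` and the plain-tensor batches are assumptions.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def recalibrate_bn(model, batches):
    """Recompute BatchNorm running estimates as a plain average over `batches`.

    Only the running_mean / running_var buffers change; learned weights are untouched.
    Layers built with track_running_stats=False are left as they are.
    """
    bns = [m for m in model.modules() if isinstance(m, BN_TYPES) and m.track_running_stats]
    saved_momentum = [bn.momentum for bn in bns]
    for bn in bns:
        bn.reset_running_stats()  # mean 0, var 1, num_batches_tracked 0
        bn.momentum = None        # None = cumulative moving average (equal weights)
    model.train()                 # BN normalizes with, and accumulates, batch statistics
    with torch.no_grad():         # no gradients and no optimizer: weights cannot change
        for x in batches:
            model(x)
    for bn, momentum in zip(bns, saved_momentum):
        bn.momentum = momentum
    model.eval()
    return model


# Toy check: after recalibration the running mean is the average of the batch means.
torch.manual_seed(0)
net = nn.Sequential(nn.BatchNorm2d(3))
batches = [2.0 * torch.randn(32, 3, 8, 8) + 5.0 for _ in range(10)]
recalibrate_bn(net, batches)
expected_mean = torch.stack([b.mean(dim=(0, 2, 3)) for b in batches]).mean(dim=0)
print(net[0].running_mean, expected_mean)
```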

Since you’ve isolated a possible reason for the difference in losses to the last batchnorm layer, it would be interesting to see the stats (min, max, mean, std) of the output of this particular bn layer using 1) a training batch and 2) a validation batch after switching to eval().
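A sketch of how such statistics could be collected without editing forward(): a forward hook (`register_forward_hook`) on the BatchNorm module itself. Hooking the BN layer captures its output before the residual addition and the ReLU, which is not the same tensor as the output of a whole stage such as layer3. The toy two-layer model stands in for the real network (for example `model.layer3[-1].bn2` in the ResNetV2 code of this thread); names and sizes are illustrative.

```python
import torch
import torch.nn as nn

stats = {}


def record_stats(name):
    """Return a forward hook storing min/max/mean/std of the module's output."""
    def hook(module, inputs, output):
        out = output.detach()
        stats[name] = (out.min().item(), out.max().item(),
                       out.mean().item(), out.std().item())
    return hook


torch.manual_seed(0)
# Toy stand-in for the last BN layer of the network.
model = nn.Sequential(nn.Conv2d(5, 16, 3, padding=1), nn.BatchNorm2d(16))
handle = model[1].register_forward_hook(record_stats("last_bn"))

train_batch = torch.randn(32, 5, 32, 32)
val_batch = torch.randn(32, 5, 32, 32)

with torch.no_grad():
    model.train()       # note: this forward pass also updates running_mean / running_var
    model(train_batch)
    train_stats = stats["last_bn"]
    model.eval()        # BN now normalizes with its running estimates
    model(val_batch)
    eval_stats = stats["last_bn"]

handle.remove()  # detach the hook when done
print("train:", train_stats)
print("eval: ", eval_stats)
```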

@ptrblck
What do you mean exactly? Does it mean:

  • load a pretrained model
  • load a batch from the training set
  • call model.train()
  • during model(input), insert a print like this:

def forward(self, x, reddening):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    print(out.min(), out.max(), out.mean(), out.std())  # <-- stats of the layer3 output
    out = F.avg_pool2d(out, out.size()[3])
    out = out.view(out.size(0), -1)
    out = torch.cat((out, reddening), dim=1)  # add the reddening
    out = self.linear(out)
    return out

And then switch to model.eval() and do the same with a test set batch.

Is this the right protocol?

So, I have written a dedicated program to load the train & test sets used to train and test the ResNet20 model where all the BNs were tracking stats, as mentioned in my post of 10th December. Then I load the trained model and alternately process a batch of training samples with model.train() and a batch of testing samples with model.eval(). Notice that here I do not use with torch.no_grad(), as I am using the same code for both processing modes; I just switch model.train()/model.eval() at the beginning of the processing.

So far so good; here are the results, which I must say puzzled me:

Use device…: cuda
train_loader length= 781
test_loader length= 781
batch size: 128
train[ 0 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(298.8851, device='cuda:0', grad_fn=) tensor(0.2251, device='cuda:0', grad_fn=) tensor(0.8224, device='cuda:0', grad_fn=)
Train : loss = 6.8468918800354
test [ 0 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(196.5387, device='cuda:0', grad_fn=) tensor(0.2451, device='cuda:0', grad_fn=) tensor(0.8113, device='cuda:0', grad_fn=)
Test : loss = 7.388835430145264
train[ 1 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(532.2750, device='cuda:0', grad_fn=) tensor(0.2222, device='cuda:0', grad_fn=) tensor(0.7947, device='cuda:0', grad_fn=)
Train : loss = 7.686824798583984
test [ 1 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(235.0206, device='cuda:0', grad_fn=) tensor(0.2438, device='cuda:0', grad_fn=) tensor(0.7707, device='cuda:0', grad_fn=)
Test : loss = 7.462079048156738
train[ 2 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(148.6887, device='cuda:0', grad_fn=) tensor(0.2305, device='cuda:0', grad_fn=) tensor(0.8213, device='cuda:0', grad_fn=)
Train : loss = 7.347263813018799
test [ 2 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(78.3936, device='cuda:0', grad_fn=) tensor(0.2437, device='cuda:0', grad_fn=) tensor(0.6861, device='cuda:0', grad_fn=)
Test : loss = 7.019978046417236
train[ 3 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(386.5198, device='cuda:0', grad_fn=) tensor(0.2214, device='cuda:0', grad_fn=) tensor(0.8807, device='cuda:0', grad_fn=)
Train : loss = 6.6856513023376465
test [ 3 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(124.2858, device='cuda:0', grad_fn=) tensor(0.2424, device='cuda:0', grad_fn=) tensor(0.7225, device='cuda:0', grad_fn=)
Test : loss = 7.320082187652588
train[ 4 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(289.6568, device='cuda:0', grad_fn=) tensor(0.2118, device='cuda:0', grad_fn=) tensor(0.7750, device='cuda:0', grad_fn=)
Train : loss = 6.753640651702881
test [ 4 ]
out layer3: tensor(0., device='cuda:0', grad_fn=) tensor(118.0012, device='cuda:0', grad_fn=) tensor(0.2386, device='cuda:0', grad_fn=) tensor(0.6660, device='cuda:0', grad_fn=)
Test : loss = 6.969400405883789

Well, what puzzles me is that I was expecting a much lower loss for the “train batches” than for the “test batches”, and as far as I can see there is no big difference between model.train() and model.eval() mode. Maybe I am wrong, and I should have performed a regular training run with optimization, gradient descent and so on, and printed the stats of the output of the BN layers???
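Two details of that protocol can blur the comparison: every forward pass in train() mode updates the BatchNorm running estimates (even without an optimizer step or gradients), so each eval() test batch sees estimates just modified by the preceding training batch; and the train and test losses come from different batches. A sketch that evaluates the *same* batch in both modes on a deep copy of the model, so the saved estimates stay untouched. The helper name, toy model and random data are placeholders.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_in_both_modes(model, x, y):
    """Cross-entropy of one batch with BN in train() mode and in eval() mode.

    Works on a deep copy, so the running estimates of `model` are not modified.
    """
    probe = copy.deepcopy(model)
    with torch.no_grad():
        probe.train()  # batch statistics (this updates the probe's running estimates)
        loss_train_mode = F.cross_entropy(probe(x), y).item()
        probe.load_state_dict(model.state_dict())  # restore the original estimates
        probe.eval()   # running estimates
        loss_eval_mode = F.cross_entropy(probe(x), y).item()
    return loss_train_mode, loss_eval_mode


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(5, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 180))
x, y = torch.randn(32, 5, 16, 16), torch.randint(0, 180, (32,))
before = model[1].running_mean.clone()
print(loss_in_both_modes(model, x, y))
print(torch.equal(before, model[1].running_mean))  # True: original model untouched
```

A large gap between the two numbers for the same batch points at the running estimates rather than at the data.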

Hi, best wishes for 2020!

I have set up a series of scripts to investigate the (pathological) behaviour of the test loss when training a ResNet model.

As I cannot attach files, it will be a bit tedious, sorry.

(pz_synth_data.py) script to generate synthesized (rubbish) data
# This is a script to produce synthesized input data
import numpy as np
from scipy.integrate import quad
from tqdm import tqdm

#####################
class ImgSynthe(object):
    """ Synthetisation d'images : C x H x W """
    Nfilters    = 5
    lambda_min0 = 0.05
    lambda_max0 = 1.0
    noise       = 0.01
    sig0        = 0.2
    strength0   = 1.0
    zmean       = 0.12
    zsig        = 0.035
    zmin        = 0.0
    zmax        = 0.3   
    N = 64
    x = (np.linspace(0,N-1,N)+0.5)/N
    y = (np.linspace(0,N-1,N)+0.5)/N
    xv, yv = np.meshgrid(x,y)
    
    def __init__(self):
        print("Perform synthetisation")
    
    def shape(self,x,y,norm=1.0,muX=0.,sigX=1.0,muY=0.0,sigY=1.0,rho=0.0):
        """ gaussian PSF """
        rho2 = 1.0-rho*rho
        sigX2 = sigX*sigX
        sigY2 = sigY*sigY
        # bivariate Gaussian: the cross term belongs inside the -0.5/(1-rho^2) factor
        return norm *\
            np.exp(-0.5/rho2*((x-muX)**2/sigX2 + (y-muY)**2/sigY2 - 2*rho*(x-muX)*(y-muY)/(sigX*sigY)))

    def spec(self,x,norm=1.0,xmin=0.,xmax=1.):
        """ Simple emission spectrum """
        if x < xmin or x > xmax:
            return 0.0
        else:
            return 4.0*norm*(x-xmin)*(xmax-x)/(xmax-xmin)**2
    
    def intspec(self,a,b,norm,xmin,xmax):
        """ Integrale du spectre dans [a,b] """
        return quad(self.spec,a,b,args=(norm,xmin,xmax))[0]
    
    def __call__(self):
        # get z value
        z = self.zsig*np.random.randn()+self.zmean
        z = np.clip(z,self.zmin,self.zmax)
        # update spectral shape
        lambda_min = (1+z)*self.lambda_min0
        lambda_max = (1+z)*self.lambda_max0
        strength = self.strength0/(1+z)**2
        
        #update shape 
        sig = self.sig0/(1.+z)
        sigX = sig*np.random.uniform(0.1,1.0)
        sigY = sig*np.random.uniform(0.1,1.0)
        rho = np.random.uniform(0.,0.99)

        # signal in all filters
        filter_imgs = []
        for p in range(self.Nfilters):
            norm = self.intspec(p/5.,(p+1)/5.,strength,lambda_min,lambda_max)
            #print('norm: ',norm)
            img = self.shape(self.xv,self.yv,norm=norm,muX=0.5,sigX=sigX,muY=0.5,sigY=sigY,rho=rho)
            img += np.random.normal(0,self.noise,size=img.shape)
            filter_imgs.append(img)
        # convert to float32 to save space
        imgs =  np.array(filter_imgs).astype("float32") # Nfilters x H x W
        # transpose for compatibility
        imgs = np.transpose(imgs,(1,2,0)) # H x W x Nfilters
        
        return { 'image': imgs, 'z':np.float32(z), 'ebv':np.float32(0.) } # ebv for compatibility

########################
def makeSet(gen,N,file="tmp.npz"):
    imgs = []
    zs   = []
    ebvs = []
    for i in tqdm(range(N),ascii=True,desc=file):
        data = gen()
        imgs.append(data['image'])
        zs.append(data['z'])
        ebvs.append(data['ebv'])
        
    zarr = np.array(zs)
    zarr = np.expand_dims(zarr,axis=-1)
    ebvarr = np.array(ebvs)
    ebvarr = np.expand_dims(ebvarr,axis=-1)
    np.savez(file,data=np.array(imgs),z=zarr,ebv=ebvarr)

################################
if __name__ == '__main__':
    gen = ImgSynthe()

    Ntrain = 128*10**3
    makeSet(gen,Ntrain,file="train_synth_128k.npz")
        
    Ntest  = 128*10**3
    makeSet(gen,Ntest,file="test_synth_128k.npz")

    Ntest_ref   = 10*10**3
    makeSet(gen,Ntest_ref,file="test_synth_10k.npz")

    print("All done. Bye")

pz_utils.py: now a utility script to be used for training; it contains the classes for data loading and for the network design (ResNet and a simple ConvNet).

import random

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init  #for ResNetV2

import numpy as np

########### A simple convnet  ########
class PzConv2d(nn.Module):
    """ Convolution 2D Layer followed by PReLU activation
    """
    def __init__(self, n_in_channels, n_out_channels, **kwargs):
        super(PzConv2d, self).__init__()
        self.conv = nn.Conv2d(n_in_channels, n_out_channels, bias=True,
                            **kwargs)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.constant_(self.conv.bias,0.1)
        self.activ = nn.PReLU(num_parameters=n_out_channels, init=0.25)

    def forward(self, x):
        x = self.conv(x)
        return self.activ(x)


class PzPool2d(nn.Module):
    """ Average Pooling Layer
    """
    def __init__(self, kernel_size, stride, padding=0):
        super(PzPool2d, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 ceil_mode=True,
                                 count_include_pad=False)

    def forward(self, x):
        return self.pool(x)


class PzFullyConnected(nn.Module):
    """ Dense or Fully Connected Layer followed by ReLU
    """
    def __init__(self, n_inputs, n_outputs, withrelu=True, **kwargs):
        super(PzFullyConnected, self).__init__()
        self.withrelu = withrelu
        self.linear = nn.Linear(n_inputs, n_outputs, bias=True)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0.1)
        self.activ = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        if self.withrelu:
            x = self.activ(x)
        return x
    
class NetCNNRed(nn.Module):
    def __init__(self,n_input_channels,debug=False):
        super(NetCNNRed, self).__init__()
        self.debug = debug
        # the number of bins to represent the output photo-z
        self.n_bins = 180

        self.conv0 = PzConv2d(n_in_channels=n_input_channels,
                              n_out_channels=64,
                              kernel_size=5,padding=2)
        self.pool0 = PzPool2d(kernel_size=2,stride=2,padding=0)

        self.conv1 = PzConv2d(n_in_channels=64,
                              n_out_channels=92,
                              kernel_size=3,padding=2)
        self.pool1 = PzPool2d(kernel_size=2,stride=2,padding=0)

        self.conv2 = PzConv2d(n_in_channels=92,
                              n_out_channels=128,
                              kernel_size=3,padding=2)
        self.pool2 = PzPool2d(kernel_size=2,stride=2,padding=0)


        self.fc0 = PzFullyConnected(n_inputs=12801,n_outputs=1024)
        self.fc1 = PzFullyConnected(n_inputs=1024,n_outputs=180)

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, x, reddening):
        # x: image tensor N_batch, Channels, Height, Width
        #    size N, Channels = 5 filters, H, W = 64 pixels
        # reddening: used (concatenated before the FC layers)

        #save original image
##        x_in = x

        if self.debug: print("input shape: ",x.size())

        # stage 0 conv 64 x 5x5
        x = self.conv0(x)
        if self.debug: print("conv0 shape: ",x.size())
        x = self.pool0(x)
        if self.debug: print("conv0p shape: ",x.size()) 

        # stage 1 conv 92 x 3x3
        x = self.conv1(x)
        if self.debug: print("conv1 shape: ",x.size())
        x = self.pool1(x)
        if self.debug: print("conv1p shape: ",x.size()) 

        # stage 2 conv 128 x 3x3
        x = self.conv2(x)
        if self.debug: print("conv2 shape: ",x.size())
        x = self.pool2(x)
        if self.debug: print("conv2p shape: ",x.size()) 


        if self.debug: print('>>>>>>> FC part :START <<<<<<<')
        flat   = x.view(-1,self.num_flat_features(x))
        concat = torch.cat((flat,reddening),dim=1)
        if self.debug: print('concat shape: ', concat.size())

        x = self.fc0(concat)
        if self.debug: print('fc0 shape: ',x.size())
        x = self.fc1(x)
        if self.debug: print('fc1 shape: ', x.size())

        output = x

        if self.debug: print('output shape: ',output.size())

##        params = {"output": output, "x": x_in, "reddening": reddening}

        #return params           
        return output


################### Resnet ##########

## Code adapted from https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlockV2(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A', h=1.0, track_run_stat=True, debug=False):
        super(BasicBlockV2, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat) # true is the default
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes,track_running_stats=track_run_stat)
        self.h = h
        self.debug = debug

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes, track_running_stats=track_run_stat)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.debug:
            print('last bn2: ',self.bn2.weight,', ',
                  self.bn2.bias,', ',
                  self.bn2.running_mean,', ',
                  self.bn2.running_var)

        out = self.shortcut(x) + self.h * out  # Zhang et al. h=1 for the default Resnet
        out = F.relu(out)
        return out


class ResNetV2(nn.Module):
    def __init__(self, block, num_blocks, num_input_channels = 5, num_classes=10, h=1.0):
        super(ResNetV2, self).__init__()
        self.in_planes = 16

        self.num_input_channels = num_input_channels
        self.n_bins = num_classes


##        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(self.num_input_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16,track_running_stats=True) # track_running_stats=True is the default
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, h=h, track_run_stat=True)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, h=h, track_run_stat=True)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, h=h, track_run_stat=True)
##        self.linear = nn.Linear(64, num_classes)
        ## JEC: add the reddening variable
        self.linear = nn.Linear(64+1, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride, h=1.0, track_run_stat=True, debug=False):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, h=h, track_run_stat=track_run_stat, debug=debug))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, reddening):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = torch.cat((out,reddening),dim=1)  ## add the reddening
        out = self.linear(out)
        return out


def resnet20(h=0.1):
    return ResNetV2(BasicBlockV2, [3, 3, 3], num_input_channels = 5, num_classes=180, h=h)

## ################ Data load/augmentation #########

class DatasetPz(Dataset):
    """ Load the data set which is supposed to be a Numpy structured array
        'z' : the true redshift array
        'ebv': reddening array
        'data': the image tensors  N H W C with C: number of channels (e.g. 5 filters)
    """

    def __init__(self, file_path, transform=None):
        self.data = np.load(file_path)  # load dataset into numpy array  N H W C
        #print('dataset data type = ',self.data['data'].dtype)
        #transform into Float32
        self.z = self.data['z'].astype("float32")
        self.ebv = self.data['ebv'].astype("float32")
        self.imgs = self.data['data'].astype("float32")

        self.z_min = np.min(self.z)
        self.z_max = np.max(self.z)

        print("DatasetPz zmin: ",self.z_min," zmax: ",self.z_max)

        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """
        Parameters
        ----------
        index:  index of the image, redshift, reddening

        Returns
        -------
        a dictionary with the image, redshift, reddening

        Apply a transform to the image if needed
        """
        image = self.imgs[index]
        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'z': self.z[index], 'ebv': self.ebv[index]}

class ToTensorPz(object):
    """ Transform the tensor from HWC(TensorFlow default) to CHW (Torch)
    """
    def __call__(self, pic):
        # use copy here to avoid crash
        img = torch.from_numpy(pic.transpose((2, 0, 1)).copy())
        return img
    def __repr__(self):
        return self.__class__.__name__ + '()'


class RandomApplyPz(transforms.RandomApply):
    """Apply randomly a list of transformations with a given probability

    Args:
        transforms (list or tuple): list of transformations
        p (float): list of probabilities
    """

    def __init__(self, transforms):
        super(RandomApplyPz, self).__init__(transforms)

    def __call__(self, img):
        # for each list of transforms
        # apply random sample to apply or not the transform
        for itset in range(len(self.transforms)):
            transf = self.transforms[itset]
            t = random.choice(transf)
            #### print('t:=',t)
            img = t(img)

        return img

def flipH(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    an image flipped wrt the horizontal axis
    """
    return np.flip(a,0)

def flipV(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    an image flipped wrt the vertical axis
    """
    return np.flip(a,1)

def rot90(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    an image rotated 90deg anti-clockwise
    """
    return np.rot90(a,1)

def rot180(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    an image rotated 180deg anti-clockwise
    """
    return np.rot90(a,2)

def rot270(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    an image rotated 270deg anti-clockwise
    """
    return np.rot90(a,3)

def identity(a):
    """
    Parameters
    ----------
    a: an image

    Returns
    -------
    the same image
    """
    return a

Continuation of the previous post.

pz_train_resnet.py : the script for training
#This is a standalone script to train resnet for debugging
import argparse
import os
import random
import types

import torch
import torch.nn.functional as F

import numpy as np

from pz_utils import *

## ###################### TRAIN ##################
def train(args, model, device, train_loader, transforms, optimizer, epoch,
          weights_class=None):

    """
    Training phase 
    Minimize Cross Entropy Loss
    """

    assert transforms

    # switch network layers to Training mode
    model.train()

    n_bins = model.n_bins # the last number of neurons

    # JEC Todo : fix them both for training and testing
    z_min = 0.0
    z_max = 1.0


    train_loss = 0
    # scans the batches
    Nbatch = len(train_loader)

    for i_batch, sample_batched in enumerate(train_loader):
        img_batch = sample_batched['image']   #input image
        ebv_batch = sample_batched['ebv']     #input redenning
        z_batch   = sample_batched['z']       #target

        batch_size = len(img_batch)

        transf_size = len(transforms)
        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)

        #for CrossEntropyLoss no hot-vector
        new_z_batch = torch.zeros(batch_size,dtype=torch.long)

        for i in range(batch_size):
            # transform the images
            img = img_batch[i].numpy()

            for it in range(transf_size):
                img = transforms[it](img)

            new_img_batch[i] = img

            # transform the redshift in bin number
            z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
            z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
            new_z_batch[i] = z

        # send the inputs and target to the device
        new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                            ebv_batch.to(device), \
                                            new_z_batch.to(device)


        # reset the gradients
        optimizer.zero_grad()
        # Feedforward
        output = model(new_img_batch, ebv_batch)


        # the loss
        loss = F.cross_entropy(output,new_z_batch, weight=weights_class)
        train_loss += loss.item()
        # backprop to compute the gradients
        loss.backward()
        # perform an optimizer step to modify the weights
        optimizer.step()

        # some debug
        if i_batch % max(1, Nbatch//10) == 0:  # max() avoids a division by zero for < 10 batches
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, i_batch * batch_size, Nbatch*batch_size,
                 100. * i_batch / Nbatch, loss.item()))


    # return loss.item()
    return train_loss/Nbatch

## ################### TEST ###################

def test(args, model, device, test_loader, transforms, epoch,
         weights_class=None):
    """
    Testing phase.
    
    """
    # switch network layers to Testing mode
    model.eval()
    n_bins = model.n_bins # the last number of neurons
    # JEC Todo : fix them both for training and testing
    z_min = 0.0
    z_max = 1.0
    largeur_bin=(z_max-z_min)/n_bins

    test_loss = 0
    correct   = 0
    Nbatch = len(test_loader)
    # turn off the computation of the gradient for all tensors
    with torch.no_grad():
        # scans the batches
        for i_batch, sample_batched in enumerate(test_loader):
            img_batch = sample_batched['image']  # input image
            ebv_batch = sample_batched['ebv']    # input redenning
            z_batch   = sample_batched['z']      # target

            # transform the images
            batch_size = len(img_batch)
            transf_size = len(transforms)
            new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)
            # for CrossEntropyLoss no one-hot vector is needed
            new_z_batch = torch.zeros(batch_size, dtype=torch.long)

            for i in range(batch_size):
                # transform the images
                img = img_batch[i].numpy()

                for it in range(transf_size):
                    img = transforms[it](img)

                new_img_batch[i] = img

                # convert the redshift to a class (bin) index by hand
                z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
                z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
                new_z_batch[i] = z

            # send the inputs and target to the device
            new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                                ebv_batch.to(device), \
                                                new_z_batch.to(device)


            # Feedforward
            output = model(new_img_batch, ebv_batch)


            # batch-mean loss (cross_entropy averages over the batch by default); averaged over batches below
            test_loss += F.cross_entropy(output, new_z_batch,
                                         weight=weights_class).item()
            # argmax of the logits equals argmax of the softmax probabilities
            pred  = output.argmax(dim=1)
            correct_current = (pred == new_z_batch).sum().item()
            correct += correct_current


    test_loss /= Nbatch
    accuracy = 100. * correct / (Nbatch*batch_size)

    return {'loss': test_loss, 'acc': accuracy}

## ################### Reference TEST to be saved ###################
def test_ref(args, model, device, sample_batched, transforms, epoch):
    """
    Testing  reference

    """
    # switch network layers to Testing mode
    model.eval()
    n_bins = model.n_bins # the last number of neurons
    # JEC Todo : fix them both for training and testing
    z_min = 0.0
    z_max = 1.0
    largeur_bin=(z_max-z_min)/n_bins


    with torch.no_grad():

        img_batch = sample_batched['image']  # input image
        ebv_batch = sample_batched['ebv']    # input reddening
        z_batch   = sample_batched['z']      # target

        # transform the images
        batch_size = len(img_batch)
        transf_size = len(transforms)
        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)
        # for CrossEntropyLoss no one-hot vector is needed
        new_z_batch = torch.zeros(batch_size, dtype=torch.long)

        for i in range(batch_size):
            # transform the images
            img = img_batch[i].numpy()

            for it in range(transf_size):
                img = transforms[it](img)

            new_img_batch[i] = img

            # convert the redshift to a class (bin) index by hand
            z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
            z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
            new_z_batch[i] = z

        # send the inputs and target to the device
        new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                            ebv_batch.to(device), \
                                            new_z_batch.to(device)


        # Feedforward
        output = model(new_img_batch, ebv_batch)


    # return the output reference
    return {'ref_in': new_img_batch, 'ref_ebv': ebv_batch, 'ref_out': output}


## ############################### MAIN ########################

def main():
    args = types.SimpleNamespace(
        fake_data = False,         # True means random inputs, False means either REAL or SYNTHESIZED data
        use_weighted_class = False,  # to reweight the class
        weights_file = './synth_weight.npy',
        network = "resnet20",      # options: "resnet20" or "convnet"
        deterministic = False,     #    torch.backends.cudnn.deterministic
        benchmark     = True,      #    torch.backends.cudnn.benchmark
        data_augmentation = True, # whether to use data augmentation
        resume = False,            # resume a session
        checkpoint_file = "dbg-resnet20_cudnn_False_True_data_True_19.pth",
        history_loss_cpt_file = "dbg-history-resnet20_cudnn_False_True_data_True_19.npy",
        Nepochs         = 20, 
        batch_size      = 128,
        test_batch_size = 128,
        lr_init = 0.01,
        weight_decay = 0.0,
        momentum = 0.9,
        lr_decay = 0.1,
        no_cuda = False,
        root_file = "./",
        # data used if fake_data=False
# Real data
#        train_file     = './data/train100k.npz',
#        test_file      = './data/test100k.npz',
#        test_ref_file  = './data/test10k.npz'
# Synthesized data
        train_file     = './data/train_synth_128k.npz',
        test_file      = './data/test_synth_128k.npz',
        test_ref_file  = './data/test_synth_10k.npz'
    )
    
    args.checkpoint_file = args.root_file + args.checkpoint_file
    args.history_loss_cpt_file = args.root_file + args.history_loss_cpt_file


    print("\n### Training model ###")
    print("> Parameters:")
    for p, v in vars(args).items():
        print('\t{}: {}'.format(p, v))

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")
    print("Use device....: ",device)

    # pin_memory speeds up CPU->GPU transfers on NVIDIA GPUs
    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}


    weights_class = None # if reweighting is needed 
    
    if args.fake_data: # Use fake data train/test/test_ref
        ######################################        
        # Train data
        Ntrain = 10**3
        train_loader = []
        for i in range(Ntrain):
            img_train = torch.zeros(args.batch_size,64,64,5).uniform_(-1.,10.) # batch_sizexHxWxC
            ebv_train = torch.zeros(args.batch_size,1).exponential_(0.05)
            z_train   = torch.zeros(args.batch_size,1).normal_(0.1,0.03).clamp_(0,0.3)
            train_loader.append({'image':img_train,'ebv':ebv_train,'z':z_train})


        # Test data
        Ntest = 10**3
        test_loader = []
        for i in range(Ntest):
            img_test = torch.zeros(args.batch_size,64,64,5).uniform_(-1.,10.) # batch_sizexHxWxC
            ebv_test = torch.zeros(args.batch_size,1).exponential_(0.05)
            z_test   = torch.zeros(args.batch_size,1).normal_(0.1,0.03).clamp_(0,0.3)
            test_loader.append({'image':img_test,'ebv':ebv_test,'z':z_test})


        # Reference test
        test_ref_loader = []
        img_test_ref = torch.ones(10,64,64,5) # NxHxWxC
        ebv_test_ref = torch.ones(10,1)
        z_test_ref   = torch.ones(10,1)
        test_ref_loader.append({'image':img_test_ref,'ebv':ebv_test_ref,'z':z_test_ref})
        sample_test_ref = test_ref_loader[0]  # reference batch; otherwise undefined when fake_data=True

    else:  # Use real/synthesized data train/test/test_ref
        ######################################
        if args.use_weighted_class :
            weights_class = np.load(args.weights_file)
            weights_class = torch.from_numpy(weights_class).to(device=device,dtype=torch.float)
        
        # Train data
        train_dataset = DatasetPz(args.train_file, transform=None)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,drop_last=True,
            shuffle=True, **kwargs)
        print("train_loader length=",len(train_loader))#, " WARNING: NO SHUFFLE")

        # Test data
        test_dataset = DatasetPz(args.test_file, transform=None)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.test_batch_size, drop_last=True,
            shuffle=False, **kwargs)
        print("test_loader length=",len(test_loader))
        # Reference data
        test_ref_dataset = DatasetPz(args.test_ref_file, transform=None)
        test_ref_loader = torch.utils.data.DataLoader(
            test_ref_dataset,
            batch_size=10, drop_last=True,
            shuffle=False, **kwargs)
        print("test_ref_loader length=",len(test_ref_loader))
        testref_loader_iterator = iter(test_ref_loader)
        sample_test_ref = next(testref_loader_iterator) # once for all



    # The Network
    if args.network == "resnet20":
        model = resnet20()
    else:
        args.network = "convnet"
        model = NetCNNRed(5)
    model.to(device)


    #Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_init,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=False)

    #Scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=args.lr_decay, patience=5,
                                                     verbose=True)

    





    # Optional data augmentation for the training set
    if args.data_augmentation:
        train_transforms = [RandomApplyPz([[flipH,flipV,identity],
                                           [rot90,rot180,rot270,identity]]),
                            ToTensorPz()]
    else:
        train_transforms = [ToTensorPz()]

    # No data augmentation for test set
    test_transforms = [ToTensorPz()]

    # Loop on the epochs
    train_loss_history = []
    test_loss_history  = []

    # load a previous session if required
    start_epoch = 0
    if args.resume :
        # load checkpoint of model/scheduler/optimizer
        if os.path.isfile(args.checkpoint_file):
            print("=> loading checkpoint '{}'".format(args.checkpoint_file))
            checkpoint = torch.load(args.checkpoint_file)
            # the first epoch for the new training
            start_epoch = checkpoint['epoch']
            # model update state
            model.load_state_dict(checkpoint['model_state_dict'])
            # scheduler update state
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            # optimizer update state
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("=> FATAL no  checkpoint '{}'".format(args.checkpoint_file))
            return

        #load previous history of losses
        if os.path.isfile(args.history_loss_cpt_file):
            loss_history = np.load(args.history_loss_cpt_file)
            train_loss_history = loss_history[0].tolist()
            test_loss_history  = loss_history[1].tolist()
        else:
            print("=> FATAL no history loss checkpoint '{}'".format(args.history_loss_cpt_file))
            return
    else:
        print("=> no checkpoint, starting from scratch")


    #seeds manual init
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    random.seed(0)

    # Cuda Cudnn
    torch.backends.cudnn.deterministic = args.deterministic
    torch.backends.cudnn.benchmark = args.benchmark

    for epoch in range(start_epoch, start_epoch+args.Nepochs):

        print("process epoch[",epoch,"]: LR = ",end='')
        for param_group in optimizer.param_groups:
            print(param_group['lr'])

        # training
        train_loss = train(args, model, device, train_loader,
                           train_transforms, optimizer, epoch, weights_class)

        # test
        test_loss = test(args, model, device, test_loader,
                         test_transforms, epoch, weights_class)
        

        # bookkeeping
        train_loss_history.append(train_loss)
        test_loss_history.append(test_loss['loss'])

        print(f"Epoch {epoch}: Train loss {train_loss:.6f}, Test loss {test_loss['loss']:.6f}")

        #step of the scheduler
        scheduler.step(test_loss['loss'])

        # Ref test
        test_ref_out = test_ref(args, model, device,  sample_test_ref,
                                test_transforms, epoch)

        

        #save model state and reference test in/out
        state = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            **test_ref_out
            }

        torch.save(state,args.root_file+"/dbg-"+args.network+"_cudnn_"+str(args.deterministic)+"_"+str(args.benchmark)+"_data_"+str(args.data_augmentation)+"_"+str(epoch)+".pth")
        # save intermediate history
        np.save(args.root_file+"/dbg-history-"+args.network+"_cudnn_"+str(args.deterministic)+"_"+str(args.benchmark)+"_data_"+str(args.data_augmentation)+"_"+str(epoch)+".npy",
                np.array((train_loss_history,test_loss_history))
                )

    #loop on epochs done
    print('End of job!')


################################
if __name__ == '__main__':
  main()
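The redshift-to-class conversion appears in `train`, `test`, `test_ref` and `process` as a per-sample Python loop. A vectorized equivalent could look like the sketch below; `redshift_to_bin` is a hypothetical helper name, and it assumes `z` is a float tensor of shape (N,) or (N, 1), as in the fake-data branch. Because the result is clamped at 0, using `floor` gives the same bins as the `int()` truncation in the loop.

```python
import torch

def redshift_to_bin(z, n_bins, z_min=0.0, z_max=1.0):
    """Hypothetical helper: map real redshifts to class indices in {0, ..., n_bins-1}.

    Mirrors the per-sample loop: normalise z to [0, 1], scale by n_bins,
    truncate to an integer, and clamp to the valid range of bins.
    """
    z = z.reshape(-1)                          # accept (N,) or (N, 1) targets
    z_norm = (z - z_min) / (z_max - z_min)     # real value, nominally in [0, 1]
    idx = torch.floor(z_norm * n_bins).long()  # integer bin number
    return idx.clamp_(0, n_bins - 1)           # out-of-range redshifts go to the edge bins
```

In `test()`, for example, `new_z_batch = redshift_to_bin(z_batch, n_bins, z_min, z_max)` could replace the binning lines inside the loop, leaving only the image transforms per sample.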
pz_debug_bn.py: a script to look at the loss after reloading a checkpoint (influence of the CUDA backend)
#This is a standalone script to train resnet for debugging

import argparse
import os
import random
import types

import numpy as np

from pz_utils import *

############## PROCESS BATCH (evaluation only) ################
def process(model,sample_batched,transforms,train=None,debug=False):
    #We load the model checkpoint so no more training
    model.eval()

    n_bins = model.n_bins # the last number of neurons

    z_min = 0.0
    z_max = 1.0

    #    # turn off the computation of the gradient for all tensors
    with torch.no_grad():

        #get next data
        img_batch = sample_batched['image']   #input image
        ebv_batch = sample_batched['ebv']     #input reddening
        z_batch   = sample_batched['z']       #target

        batch_size  = len(img_batch)
        transf_size = len(transforms)

        new_img_batch = torch.zeros_like(img_batch).permute(0, 3, 1, 2)

        #for CrossEntropyLoss no one-hot vector is needed
        new_z_batch = torch.zeros(batch_size,dtype=torch.long)

        for i in range(batch_size):
            # transform the images
            img = img_batch[i].numpy()

            for it in range(transf_size):
                img = transforms[it](img)

            new_img_batch[i] = img

            # transform the redshift into a bin number
            z = (z_batch[i] - z_min) / (z_max - z_min)    # z \in 0..1                => z is real
            z = max(0, min(n_bins - 1, int(z * n_bins)))  # z \in {0,1,.., n_bins-1}  => z is an integer
            new_z_batch[i] = z

        # send the inputs and target to the device
        new_img_batch, ebv_batch, new_z_batch = new_img_batch.to(device), \
                                                ebv_batch.to(device), \
                                                new_z_batch.to(device)


        # Feedforward
        output = model(new_img_batch, ebv_batch)
        print("ebv_batch: ",ebv_batch[:5])
        print("new_img_batch:\n",new_img_batch[:5])
        print("output:\n",output[:5])

        # the loss
        loss = F.cross_entropy(output,new_z_batch)

    return loss.item()



############## MAIN ################

#
# set manual seeds 
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)

#Cuda 
#torch.backends.cudnn.enabled = False
#torch.backends.cudnn.deterministic = False
#torch.backends.cudnn.benchmark = True


use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
print("Use device....: ",device)

# pin_memory speeds up CPU->GPU transfers on NVIDIA GPUs
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}


#model
model = resnet20()
#model = NetCNNRed(5)
model.to(device)

#load model parameters
root_file = "./"
checkpoint_file = "dbg-resnet20_cudnn_False_True_data_True_19.pth.SAV"


checkpoint_file = root_file + checkpoint_file
checkpoint = torch.load(checkpoint_file, map_location=device)  # load the saved tensors onto the current device
model.load_state_dict(checkpoint['model_state_dict'])


#cross-check model inference on the reference batch
ref_input_img = checkpoint['ref_in']
ref_input_ebv = checkpoint['ref_ebv']
ref_output    = checkpoint['ref_out']
model.eval()
with torch.no_grad():
    output = model(ref_input_img,ref_input_ebv)
# Compare
diff = (output - ref_output).abs().max().item()
print('max abs error: ', diff)
if diff>0:
    print('ref_out\n',ref_output,'\n',F.softmax(ref_output,dim=1))
    print('new_out\n',output,'\n',F.softmax(output,dim=1))
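The debug script reports a mismatch as soon as `diff > 0`. With `torch.backends.cudnn.benchmark = True` and `deterministic = False` (the settings encoded in the checkpoint name), cuDNN may select different convolution algorithms between runs, so small float32 differences can appear even when the model is reloaded correctly. A tolerance-based comparison separates such noise from a real discrepancy; the sketch below is one way to do it, and the helper name and the `rtol`/`atol` values are assumptions to tune.

```python
import torch

def compare_outputs(new_out, ref_out, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: compare a recomputed output with a stored reference.

    Returns (max_abs_error, close, same_argmax):
      close       -> True if every element agrees within rtol/atol
      same_argmax -> True if the predicted class (bin) is identical per sample
    """
    new_out = new_out.detach().to(ref_out)  # match the reference device and dtype
    max_abs_err = (new_out - ref_out).abs().max().item()
    close = torch.allclose(new_out, ref_out, rtol=rtol, atol=atol)
    same_argmax = torch.equal(new_out.argmax(dim=1), ref_out.argmax(dim=1))
    return max_abs_err, close, same_argmax
```

In the debug script, `diff, close, same = compare_outputs(output, ref_output)` could replace the `diff` computation; `close` being False while `same` is True would mean the logits drifted slightly but the predicted bins still agree.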

With all that:

  1. you can generate synthesized data;
  2. you can train resnet20 (and a simpler convnet) under different conditions:
    a) fixing the CUDA/cuDNN backend settings,
    b) with or without purely random (fake) data,
    c) with unweighted classes, or with class weights that counterbalance the label distribution obtained by binning the ‘z’ variable into classes;
  3. during training, a ‘pth’ file with the model/optimizer/scheduler checkpoint and an ‘npy’ file tracking the train/test losses are saved at every epoch. You can plot them with the following script:
import numpy as np
import matplotlib.pyplot as plt
import argparse

def main():

    parser = argparse.ArgumentParser(description='loss-plot')
    parser.add_argument('--file', type=str, required=True)
    parser.add_argument('--tag', type=str, required=True)
    
    args = parser.parse_args()

    data = np.load(args.file,allow_pickle=True)
    
    train_loss = data[0]
    test_loss = data[1]
    plt.plot(train_loss, label="train")
    plt.plot(test_loss, label="test")
    plt.xlabel('epoch')
    plt.ylabel('Cross Entropy loss')
    plt.legend()
    plt.grid()
    plt.savefig('history_'+args.tag+'.png')
        
    plt.show()

if __name__ == '__main__':
  main()

In my numerical experiments, the test loss on the synthesized data tends to be erratic only when the classes are not reweighted; this is not the case for the real data, where reweighting has no effect on the erratic test-loss behaviour. You will certainly have questions first; feel free to ask.
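The post does not show how `./synth_weight.npy` was built. For the reweighting in item 2c, a common choice is inverse-frequency class weights; the sketch below assumes that scheme (the helper name, the `eps` smoothing and the normalisation to a mean weight of 1 are all assumptions) and bins `z` with the same truncate-and-clamp rule as the training code.

```python
import numpy as np

def inverse_frequency_weights(z, n_bins, z_min=0.0, z_max=1.0, eps=1.0):
    """Hypothetical helper: per-class weights for F.cross_entropy(weight=...).

    Bins the redshifts, counts samples per bin, weights each class by
    1 / (count + eps), then rescales so the weights average to 1.
    """
    z = np.asarray(z, dtype=np.float64).reshape(-1)
    idx = np.floor((z - z_min) / (z_max - z_min) * n_bins).astype(np.int64)
    idx = np.clip(idx, 0, n_bins - 1)                 # same clamping as the training loop
    counts = np.bincount(idx, minlength=n_bins).astype(np.float64)
    w = 1.0 / (counts + eps)                          # eps keeps empty bins finite
    return w * (n_bins / w.sum())                     # mean weight == 1
```

`np.save('./synth_weight.npy', inverse_frequency_weights(z_train, n_bins=180))`, where `z_train` stands for the training redshifts (hypothetical name), would produce a file in the format `main()` loads when `use_weighted_class=True`. With the default `reduction='mean'`, weighted `F.cross_entropy` divides by the summed weights of the targets actually present in the batch, so the large weights of empty bins never enter the loss.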