The labels of my dataset (images) change during training

I am working on a spectral super-resolution problem where the data consist of an RGB image with 3 channels (the input image) and a hyperspectral image with 31 channels (the label that the output is compared with).
During training, the pixel values of the labels change for no apparent reason.
Any help with this issue would be appreciated.

Could you explain your use case a bit more and especially how the targets are used?
I would assume they are only passed to the loss function so I wouldn’t know how they can be changed.
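
One way to test this assumption could be to snapshot the target before the forward/backward pass and compare it afterwards. The following is only a minimal sketch on random CPU tensors; the single `nn.Conv2d` model and the helper name `check_target_unchanged` are placeholders, not code from this thread:

```python
import torch
import torch.nn as nn


def check_target_unchanged(model, criterion, inputs, target):
    """Run one forward/backward pass and report whether `target` was modified."""
    snapshot = target.clone()  # independent copy taken before the step
    loss = criterion(model(inputs), target)
    loss.backward()
    unchanged = torch.equal(snapshot, target)
    finite = bool(torch.isfinite(target).all())
    return unchanged, finite


torch.manual_seed(0)
model = nn.Conv2d(3, 31, kernel_size=3, padding=1, bias=False)  # stand-in model
criterion = nn.L1Loss()
inputs = torch.rand(2, 3, 16, 16)   # N x 3 x H x W, values in [0, 1)
target = torch.rand(2, 31, 16, 16)  # N x 31 x H x W, values in [0, 1)

unchanged, finite = check_target_unchanged(model, criterion, inputs, target)
print("target unchanged:", unchanged, "| target finite:", finite)
```

If the same check reports a change in the real training loop, the line that modifies the target should lie between the snapshot and the comparison.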

Thanks a lot for your concern.
I am using a model to reconstruct a hyperspectral image from an RGB image. The input to the model is an RGB image of shape Nx3xWxH and the output is a hyperspectral image of shape NxCxWxH with C=31. The labels are ground-truth hyperspectral images of the same shape (NxCxWxH, C=31). I only use the labels in two places: when fetching the data and moving it to the GPU, and when computing the L1 loss between the reconstructed and ground-truth images. The problem I am facing is that the labels' pixel values change during training until they contain NaN values. I think some of the images may be corrupted somehow, and I am checking that now. I need to know the following:
1- Is it possible for the labels to be changed in any way?
2- The training is stable as long as I do not switch the model mode with model.train() or model.eval().
3- The training is fine as long as I do not split the code into a train function and a test function, which is quite weird. Sorry for the long post.
Thanks again a lot for your concern.

  1. Yes, since the labels are a tensor you can of course change a tensor. However, since they should be used in the loss calculation only, I wouldn’t know where they are changed without seeing code.

  2. I don’t see a question here. If the training is “unstable” when eval() is called, I assume you are seeing a higher loss value which you wouldn’t expect? If so, I would guess that some norm layers (e.g. batchnorm layers) are updating their internal stats with noisy batch stats (we have a lot of similar discussions about these issues).

  3. I also don’t see a question here, so you would need to describe the issue more.
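
To illustrate the guess in point 2, here is a small sketch of how a batchnorm layer treats its running statistics in the two modes. It uses a standalone `nn.BatchNorm2d` layer on random data, which is an assumption for illustration only; whether the model in question contains such layers is not known from the posts so far:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2  # batch with non-zero mean and non-unit variance

bn.train()
before = bn.running_mean.clone()
bn(x)  # a forward pass in train mode updates the running statistics
changed_in_train = not torch.equal(before, bn.running_mean)

bn.eval()
before = bn.running_mean.clone()
bn(x)  # a forward pass in eval mode uses the running statistics without updating them
changed_in_eval = not torch.equal(before, bn.running_mean)

print("changed in train mode:", changed_in_train)
print("changed in eval mode: ", changed_in_eval)
```

With small or noisy batches the running estimates can drift away from the batch statistics used during training, which would show up as a gap between the train and eval losses.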

Thanks a lot for your reply. I ran a few interesting experiments (all of them without any loss calculation or backpropagation):
1- I divided my training set into four equal parts and trained on each part individually to see whether the training is stable and whether the labels change. In this case the training loop is stable.
2- I used half of my training set with the same settings as above, and the training loop is stable.
3- If I increase the training set beyond half, the labels change even though there is no loss function or backpropagation.
4- If I make the model very simple, the training is stable.
5- If the labels are not moved to CUDA, the training is stable.
I came to the conclusion that the model and the data (labels) occupy more space than the GPU can handle, so there is a memory leak or something similar, and that it affects the labels rather than the RGB images because the labels are very large (well over 25 GB). Is that conclusion right? If you need me to post any code, just let me know.
Thanks a lot for your help and patience.

No, I don’t think the conclusion is right, as a memory leak wouldn’t show up as a memory corruption.
Yes, a minimal, executable code snippet showing this behavior would be great to debug this issue.
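
On the memory question: a rough back-of-the-envelope estimate may help. It assumes float32 tensors, plus the batch size of 16 and patch size of 64 that appear as defaults in the scripts posted in this thread, and it counts only the input and label tensors of one batch (not the model or its activations). The helper `tensor_bytes` is a made-up name for this sketch:

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor with the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


batch_size, patch = 16, 64  # defaults taken from the posted scripts
label_batch = tensor_bytes((batch_size, 31, patch, patch))
rgb_batch = tensor_bytes((batch_size, 3, patch, patch))

print(f"labels per batch: {label_batch / 2**20:.2f} MiB")
print(f"rgb per batch:    {rgb_batch / 2**20:.2f} MiB")
```

Since the posted training loop moves one batch at a time to the GPU, the size of the full dataset on disk would not by itself determine how much label data sits in GPU memory.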

Thanks for the reply.
The code for the simple model is as follows:

class Residual_Block(nn.Module):
    def __init__(self, Cn=64, ksize=3):
        super(Residual_Block, self).__init__()
        self.conv = self.make_layer(Conv_ReLU_Block, conv_num=1, cn=Cn)
        self.ouput = nn.Conv2d(in_channels=Cn, out_channels=Cn, kernel_size=ksize, stride=1, padding=int((ksize - 1) / 2), bias=False)
        self.relu = nn.PReLU()

    def make_layer(self, block, conv_num, cn):
        layer = []
        for _ in range(conv_num):
            layer.append(block(cn))
        return nn.Sequential(*layer)

    def forward(self, x):
        out = self.conv(x)
        out = self.ouput(out)
        out = out + x
        # out = self.relu(out)
        return out


class FMNet2(nn.Module):
    def __init__(self,in_channels=3,channels=128,out_channels=31):
        super(FMNet2, self).__init__()
        
        self.conv0 = nn.Conv2d(in_channels, channels, kernel_size=3,padding=1,stride=1, bias=False)
        self.res1 = Residual_Block(Cn=channels)
        self.res2 = Residual_Block(Cn=channels)
        self.res3 = Residual_Block(Cn=channels)
        self.res4 = Residual_Block(Cn=channels)
        self.res5 = Residual_Block(Cn=channels)
        self.res6 = Residual_Block(Cn=channels)
        self.res7 = Residual_Block(Cn=channels)
        self.res8 = Residual_Block(Cn=channels)
        self.res9 = Residual_Block(Cn=channels)
        self.conv1 = nn.Conv2d(channels, out_channels, kernel_size=3,padding=1,stride=1, bias=False)
    
    def forward(self,x):
        print(x.shape)
        out = self.conv0(x)
        print(out.shape)
        res = out
        out = self.res1(out)
        out = self.res2(out)
        out = self.res3(out)
        out = self.res4(out)
        out = self.res5(out)
        out = self.res6(out)
        out = self.res7(out)
        out = self.res8(out)
        out = self.res9(out)
        out = out + res
        out = self.conv1(out)
        return out

The code for the dataset class is as follows:

class HyperDataset(udata.Dataset):
    def __init__(self, mode='train'):
        self.mode = mode

        if self.mode == 'train':
            self.h5f = h5py.File('./Dataset/train_clean.h5', 'r')
        elif self.mode == 'test':
            self.h5f = h5py.File('./Dataset/test_final.h5', 'r')
        

        #self.keys = list(self.h5f.keys())
        if 'train' in self.mode:
            self.keys = list(self.h5f.keys())
            random.shuffle(self.keys)
            self.len = len(self.keys)
            
        else:
            self.keys = list(self.h5f.keys())
            self.keys.sort()
            self.len = len(self.keys)
            
       
    def __len__(self):
        #return len(self.keys)
         return self.len

    def __getitem__(self, index):
        key = str(self.keys[index])
        data = np.array(self.h5f[key])
        data = torch.Tensor(data)
        # the first part of the tuple is the RGB image and the second part is the hyperspectral image
        return data[31:34,:,:], data[0:31,:,:]
    def close(self):
        self.h5f.close()
    def shuffle(self):
        if 'train' in self.mode:
            random.shuffle(self.keys)

Note that the data is cropped and saved in one big h5 file offline before training.
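
Because the patches are stored offline, the stored values can be checked independently of training. Below is a minimal sketch of such a scan; `find_bad_samples` is a hypothetical helper, and the dictionary `demo` stands in for the opened HDF5 file (any mapping from key to array should work the same way):

```python
import numpy as np


def find_bad_samples(samples, low=0.0, high=1.0):
    """Return keys whose array has non-finite values or values outside [low, high].

    `samples` is any mapping from key to array-like, for example an opened
    h5py.File whose datasets are the stored patches.
    """
    bad = []
    for key in samples:
        arr = np.asarray(samples[key])
        if not np.isfinite(arr).all() or arr.min() < low or arr.max() > high:
            bad.append(key)
    return bad


# Stand-in for the HDF5 file: three 34-channel patches, two of them damaged.
demo = {
    "1": np.random.rand(34, 8, 8).astype(np.float32),
    "2": np.full((34, 8, 8), np.nan, dtype=np.float32),
    "3": np.random.rand(34, 8, 8).astype(np.float32) + 2.0,
}
print(find_bad_samples(demo))
```

If this scan finds nothing in the file but the check inside the training loop still fires, the values are presumably being changed after loading rather than on disk.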
The simple training loop

train_dataset = dataset.HyperDataset(mode='train')
test_dataset  = dataset.HyperDataset(mode='test')
print("Train_dataset:%d" % (len(train_dataset)))
print("Validation set samples:", len(test_dataset))
# Data Loader (Input Pipeline)
dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
val_loader   = DataLoader(dataset=test_dataset, batch_size=1,  shuffle=False, num_workers=0, pin_memory=True)
    
    
    
    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()
    logger2 = initialize_logger(log_dir2)
    # For loop training
    for epoch in range(opt.epochs):
        generator.train()
        total_loss =  utils.AverageMeter()
        for i, (img_A, img_B) in enumerate(dataloader):
            # this control statement is only to check if the data is changing or not
            # it is saved to the log file 
            if img_B.min()<0 or img_B.max()>1:
                print("yes there is problem in labels ")
                logger2.info(" IMAGBTRain Epoch [%02d],batch no: %d/%d"
                % (epoch,i+1,len(dataloader)))
            if img_A.min()<0 or img_A.max()>1:
                 print("yes there is problem in inputs ")
                 logger2.info("IMAGATRAIN Epoch [%02d],batch no: %d/%d"
                 % (epoch,i+1,len(dataloader)))
            generator.zero_grad()
            optimizer_G.zero_grad()
            # To device
            img_A = img_A.cuda()
            img_B = img_B.cuda()

            # Train Generator

            # Forword propagation
            recon_B = generator(img_A)
            # # Losses
            loss = criterion_L1(recon_B, img_B)
            # # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()
            total_loss.update(loss.item())  # store a Python float instead of a tensor

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] Time_left: %s" %
                ((epoch + 1), opt.epochs, i, len(dataloader), total_loss.avg, time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)

       #--------------------------------------
       #                    Validation
       #--------------------------------------
        generator.eval() 
        losses = utils.AverageMeter()
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                images, labels = data
                # this control statement is only to check if the data is changing or not
                # it is saved to the log file
                if labels.min()<0 or labels.max()>1:
                    print("yes there is problem in labels ")
                    logger2.info(" IMAGBVAL Epoch [%02d],batch no: %d/%d"
                    % (epoch,i+1,len(val_loader)))
                if images.min()<0 or images.max()>1:
                    print("yes there is problem in inputs ")
                    logger2.info("IMAGAVAL Epoch [%02d],batch no: %d/%d"
                    % (epoch,i+1,len(val_loader)))

                images = images.cuda()
                labels = labels.cuda()

                fake_hyper = generator(images)
                loss_v = criterion_valid(fake_hyper, labels)
                losses.update(loss_v.item())
            print("\r [Total Loss: %.4f]" %
                (losses.avg))

Thanks a lot for your help

The code is unfortunately not executable as some definitions are missing (e.g. Conv_ReLU_Block) as well as the data.

Sorry for the missing definition

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

and the Conv_ReLU_Block

class Conv_ReLU_Block(nn.Module): 
    def __init__(self, nFeat=64, ksize=3): 
        super(Conv_ReLU_Block, self).__init__()
        self.conv = nn.Conv2d(in_channels=nFeat, out_channels=nFeat, kernel_size=ksize, stride=1, padding=int((ksize - 1) / 2), bias=False)
        self.relu = nn.PReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

and the loss functions

criterion_L1 = torch.nn.L1Loss().cuda()
criterion_valid = torch.nn.L1Loss().cuda()

Regarding the data, if you mean the dataset itself, it is quite large; here is the link for it:
https://competitions.codalab.org/competitions/22225
And here is the offline code for creating the dataset:

import os
import os.path
import h5py
import cv2
import glob
import numpy as np
import argparse
import hdf5storage 
import random
from scipy.io import loadmat

parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../NTIRE2020', help="data path")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
opt = parser.parse_args()


def main():
     if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
     h5f = h5py.File('./Dataset/train_clean.h5', 'w')    
     

     process_data(h5f, patch_size=opt.patch_size, stride=opt.stride, mode='train')
     h5f.close()  # flush and close the HDF5 file once all patches are written


def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)


def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])


def process_data(h5f,patch_size, stride, mode):
    if mode == 'train':
        print("\nprocess training set ...\n")
        patch_num = 1
        filenames_hyper =glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        print(len(filenames_rgb),len(filenames_hyper))
        print("\nbefore loop ...\n")
        #for k in range(1):  # make small dataset
        for k in range(len(filenames_hyper)):
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            #mat = h5py.File(filenames_hyper[k], 'r')
            mat = loadmat(filenames_hyper[k])
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k])  # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate training sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]
                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
                patch_num += 1

        print("\ntraining set: # samples %d\n" % (patch_num-1))


if __name__ == '__main__':
    main()


Thanks a lot for your help.
If you download this dataset from the clean track, make sure that you remove image number 340, since it has a different distribution from the other images in the dataset.
If you need me to post the full, organized code, let me know. I am really grateful for your help. Thanks a lot.

Thanks for the link. It seems I might need to register and download the large dataset.
Could you remove the dataset dependency and try to come up with a minimal, executable code snippet which would reproduce the issue?
I doubt the original dataset is needed and would assume that random tensors using the same shape could reproduce the issue (or just a single original sample).
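
As a starting point, something along these lines might already be enough. It is only a sketch: the tensors are random stand-ins with the shapes from this thread, and the small `nn.Sequential` model is a placeholder rather than the posted model:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Random stand-ins: 3-channel input patches and 31-channel target patches.
rgb = torch.rand(32, 3, 64, 64)
hyper = torch.rand(32, 31, 64, 64)
loader = DataLoader(TensorDataset(rgb, hyper), batch_size=16, shuffle=True)

model = nn.Sequential(  # placeholder model, not the one from the thread
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.PReLU(),
    nn.Conv2d(16, 31, 3, padding=1, bias=False),
).to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

bad_batches = 0
for epoch in range(2):
    model.train()
    for img_a, img_b in loader:
        img_a, img_b = img_a.to(device), img_b.to(device)
        optimizer.zero_grad()
        loss = criterion(model(img_a), img_b)
        loss.backward()
        optimizer.step()
        # Check the target after the optimizer step, on the device it was used on.
        if not torch.isfinite(img_b).all() or img_b.min() < 0 or img_b.max() > 1:
            bad_batches += 1

print("batches with out-of-range targets:", bad_batches)
```

If the targets stay in range here, the model, dataset size, or loss can be swapped in one at a time until the behavior reappears.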

Thanks for your reply. I wrote a script to generate random hyperspectral data and the corresponding RGB images (run this first to generate the dataset in its folder). The script also generates the file cie_1964_w_gain.npz, which is used in the loss function. Copy this file to the location of your main file.

import numpy as np
import hdf5storage ## you may install it using pip
import cv2 as cv  
import argparse
import os
from scipy.io import loadmat
from matplotlib import pyplot as plt
from os.path import basename,join,splitext
import torch
import torch.nn as nn
from PIL import Image


## This script creates random hyperspectral images with the same width and height as the
## original dataset and 31 channels: (482, 512, 31)

img_width   = 482
img_height  = 512
img_channel = 31 
min_value   = 0.0001
filtersPath = "./NTIRE2020/cie_1964_w_gain.npz"
BIT_8 = 256
parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--root", type=str, default='./NTIRE2020', help="hyper data path")
parser.add_argument("--data_path_hyper", type=str, default='./NTIRE2020_Train_Spectral', help="hyper data path")
parser.add_argument("--data_path_rgb", type=str, default='./NTIRE2020_Train_clean', help="rgb data path")
opt = parser.parse_args()
train_data_path = "./CleanResults/1"
## the path to save dataset

##create file path if not existed for hyper
if not os.path.exists(os.path.join(opt.root, opt.data_path_hyper)):
   os.makedirs(os.path.join(opt.root, opt.data_path_hyper))
##create file path if not existed for rgb
if not os.path.exists(os.path.join(opt.root, opt.data_path_rgb)):
   os.makedirs(os.path.join(opt.root, opt.data_path_rgb))

## create camera spectral curves to create RGB images
def create_camera_curves():
    filters = np.array(
    [[ 0.41817229,  0.04383285,  1.88213184],
      [ 1.85425449 , 0.19147895 , 8.52029389],
      [ 4.47484197 , 0.46778508 ,21.28163131],
      [ 6.88603366 , 0.84577887 ,33.99399574],
      [ 8.39714516 , 1.35751927 ,43.04896615],
      [ 8.1119695  , 1.95625181 ,43.65117201],
      [ 6.61455659 , 2.80353959 ,38.19302491],
      [ 4.28065468 , 4.04979288 ,28.83148095],
      [ 1.76171245 , 5.54556362 ,16.89601022],
      [ 0.35388741 , 7.41626659 , 9.08678754],
      [ 0.08350447 ,10.07641566 , 4.78136574],
      [ 0.81983625 ,13.26840209 , 2.45180064],
      [ 2.57666885 ,16.6583405  , 1.32846351],
      [ 5.17506725 ,19.13938808 , 0.66634341],
      [ 8.24479763 ,21.03705468 , 0.2992648 ],
      [11.59403605 ,21.68813996 , 0.08726733],
      [15.43222205 ,21.81014328 , 0.        ],
      [19.22736473 ,20.89631022 , 0.        ],
      [22.19258323 ,19.00212068 , 0.        ],
      [24.47626429 ,17.00053586 , 0.        ],
      [24.59596279 ,14.39680704 , 0.        ],
      [22.54970928 ,11.54566013 , 0.        ],
      [18.73811079  ,8.70483506 , 0.        ],
      [14.16834157  ,6.19951365 , 0.        ],
      [ 9.44386149  ,3.93253499 , 0.        ],
      [ 5.87176941  ,2.35375213 , 0.        ],
      [ 3.33860341  ,1.31824378 , 0.        ],
      [ 1.77820327  ,0.6954211  , 0.        ],
      [ 0.89392678  ,0.3478177  , 0.        ],
      [ 0.43636996  ,0.16945318 , 0.        ],
      [ 0.20956822  ,0.0813007  , 0.        ]])

    bands = np.array([[400 ,410, 420, 430, 440, 450 ,460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570,
      580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700]])
    
    np.savez(filtersPath, filters=filters, bands=bands)

create_camera_curves()

def create_rgb(hyper_img, filtersPath):
    model_hs2rgb = nn.Conv2d(31, 3, 1, bias=False)
    cie_matrix = np.load(filtersPath)['filters']
    cie_matrix = torch.from_numpy(np.transpose(cie_matrix, [1, 0])).unsqueeze(-1).unsqueeze(-1).float()
    model_hs2rgb.weight.data = cie_matrix
    with torch.no_grad():
       hyper_tensor = torch.tensor(np.transpose(hyper_img, [2, 0, 1]),dtype=torch.float32)
       C,W,H = hyper_tensor.shape
       rgb_tensor = model_hs2rgb(hyper_tensor.view(1,C,W,H))
       rgb_tensor = rgb_tensor / 255
       rgb_tensor = torch.clamp(rgb_tensor, 0, 1) * 255
       rgb_tensor = rgb_tensor.squeeze(0)
       rgb_img = rgb_tensor.numpy()
       rgb_img = np.transpose(rgb_img,[1,2,0])
    return rgb_img   
    
    
## the main loop to create the dataset
for i in range(1,450):
   print("creating dataset No.",i) 
   hyper_img = np.random.rand(img_width,img_height,img_channel)
   ## this step to make the min value not less than 0.0001 to make the train stable
   hyper_img[hyper_img < min_value] = min_value
   train_data_path_hyper = os.path.join(opt.root, opt.data_path_hyper, 'train'+str(i)+'.mat')
   train_data_path_rgb = os.path.join(opt.root, opt.data_path_rgb)
   hdf5storage.savemat(train_data_path_hyper, {'cube': hyper_img}, format='5')
   # Project image to RGB
   rgbIm = np.true_divide(create_rgb(hyper_img, filtersPath), BIT_8)
   # Save image file
   # save RGB image
   cv.imwrite(os.path.join(train_data_path_rgb ,'train'+str(i)+'.png'), (rgbIm * 255).astype(np.uint8))

Run the next script first to generate the image patches offline. Adjust the paths to the NTIRE2020 folder that holds the data. The file name is train_data_proprocess

import os
import os.path
import h5py
import cv2
import glob
import numpy as np
import argparse
import hdf5storage 
import random
from scipy.io import loadmat

parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../../NTIRE2020', help="data path")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
opt = parser.parse_args()


def main():
     if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
     h5f = h5py.File('./Dataset/train_clean.h5', 'w')    
     

     process_data(h5f, patch_size=opt.patch_size, stride=opt.stride, mode='train')
     h5f.close()  # flush and close the HDF5 file once all patches are written


def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)


def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])


def process_data(h5f,patch_size, stride, mode):
    if mode == 'train':
        print("\nprocess training set ...\n")
        patch_num = 1
        filenames_hyper =glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        print(len(filenames_rgb),len(filenames_hyper))
        print("\nbefore loop ...\n")
        #for k in range(1):  # make small dataset
        for k in range(len(filenames_hyper)):
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            #mat = h5py.File(filenames_hyper[k], 'r')
            mat = loadmat(filenames_hyper[k])
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k])  # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate training sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]
                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
                patch_num += 1

        print("\ntraining set: # samples %d\n" % (patch_num-1))


if __name__ == '__main__':
    main()


Copy ten images and create two folders named NTIRE2020_Validation_Spectral and NTIRE2020_Validation_Clean for the validation. The script name is valid_data_preprocess

import os
import os.path
import h5py
from scipy.io import loadmat,savemat
import cv2
import glob
import numpy as np
import argparse
import hdf5storage


parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../../NTIRE2020', help="data path")
#parser.add_argument("--data_path", type=str, default='./NTIRE2020', help="data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")

opt = parser.parse_args()


def main():
    if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
    
    h5f = h5py.File('./Dataset/test_final.h5', 'w')    
    process_data(h5f,patch_size=opt.patch_size, stride=opt.stride, mode='valid')
    h5f.close()  # flush and close the HDF5 file once all patches are written


def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)


def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])


def process_data(h5f,patch_size, stride, mode):
    if mode == 'valid':
        print("\nprocess valid set ...\n")
        patch_num = 1
        filenames_hyper = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Validation_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Validation_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        #for k in range(1):  # make small dataset
        for k in range(len(filenames_hyper)):
            # continue
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            mat = hdf5storage.loadmat(filenames_hyper[k])
            #mat = h5py.File(filenames_hyper[k], 'r')
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k])  # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate valid sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]

                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
              
                
                patch_num += 1

        print("\nvalidation set: # samples %d\n" % (patch_num-1))


if __name__ == '__main__':
    main()

The dataset class interface, in a file called dataset.py, is as follows. A simple hack to make my code run is to not load all of the dataset batches at once: to do this, uncomment the line self.len = 32000 and comment out the line below it.

import os
import random
import h5py
import numpy as np
import torch
import torch.utils.data as udata


class HyperDataset(udata.Dataset):
    def __init__(self, mode='train'):
        self.mode = mode

        if self.mode == 'train':
            self.h5f = h5py.File('./Dataset/train_clean.h5', 'r')
        elif self.mode == 'test':
            self.h5f = h5py.File('./Dataset/test_final.h5', 'r')
        

        #self.keys = list(self.h5f.keys())
        if 'train' in self.mode:
            self.keys = list(self.h5f.keys())
            random.shuffle(self.keys)
            #self.len = 32000
            self.len = len(self.keys)
        else:
            self.keys = list(self.h5f.keys())
            self.keys.sort()
            self.len = len(self.keys)
            
       
    def __len__(self):
        #return len(self.keys)
         return self.len

    def __getitem__(self, index):
        key = str(self.keys[index])
        data = np.array(self.h5f[key])
        data = torch.Tensor(data)
        return data[31:34,:,:], data[0:31,:,:]

    

    def close(self):
        self.h5f.close()

    def shuffle(self):
        if 'train' in self.mode:
            random.shuffle(self.keys)
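
The same restriction could probably also be expressed without editing the class, by wrapping the dataset in `torch.utils.data.Subset`. This is only a sketch with a random `TensorDataset` standing in for `HyperDataset(mode='train')`, and `max_samples` is a name made up here to play the role of the hard-coded 32000:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in dataset; in this thread it would be HyperDataset(mode='train').
full_dataset = TensorDataset(torch.rand(100, 3, 8, 8), torch.rand(100, 31, 8, 8))

max_samples = 40  # plays the role of the hard-coded 32000
subset = Subset(full_dataset, range(min(max_samples, len(full_dataset))))
loader = DataLoader(subset, batch_size=16, shuffle=True, drop_last=True)

print(len(full_dataset), len(subset), len(loader))
```

Only indices below `max_samples` are ever requested from the wrapped dataset, so the remaining samples are never read.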

The main file is called main.py

import torch
import torch.nn as nn
import argparse
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
#from torch.autograd import Variable
import os
import time
#import random
#from dataset import HyperDatasetValid, HyperDatasetTrain1, HyperDatasetTrain2, HyperDatasetTrain3, HyperDatasetTrain4  # Clean Data set
import dataset # Clean Data set
from AWAN import AWAN
from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, LossTrainCSS, Loss_valid
import visdom
from train import train,test
from torch_poly_lr_decay import PolynomialLRDecay
#from BackBone import BackBone
#from RSCAN import BackBone,SpatialSpectralSRNet
#from FMNet import FMNet
#from collections import OrderedDict
#from proposed import SpatialSpectralSRNet
#from SAN import SAN


os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

parser = argparse.ArgumentParser(description="SSR")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--end_epoch", type=int, default=150+1, help="number of epochs")
parser.add_argument("--init_lr", type=float, default=1e-4, help="initial learning rate")
parser.add_argument("--decay_power", type=float, default=1.5, help="decay power")
parser.add_argument("--trade_off", type=float, default=10, help="trade_off")
parser.add_argument("--max_iter", type=float, default=3000000, help="max_iter")  # patch48:380x450/32x100-534375; patch96:82x450/32x100-113906
parser.add_argument("--outf", type=str, default="CleanResults", help='path log files')
parser.add_argument('--b1', type = float, default = 0.9, help = 'Adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type = float, default = 0.999, help = 'Adam: decay of second order momentum of gradient')
parser.add_argument('--weight_decay', type = float, default = 0, help = 'weight decay for optimizer')
parser.add_argument("--milestones", type=int, nargs='+', default=list(range(30, 150, 30)), help="epochs at which to reduce the lr")
parser.add_argument("--gamma", type=float, default=0.5, help="how much to reduce the lr each time")
opt = parser.parse_args()


def main():
    cudnn.benchmark = True

    # load dataset
    print("\nloading dataset ...")
    train_dataset = dataset.HyperDataset(mode='train')
    test_dataset  = dataset.HyperDataset(mode='test')
    print("Train_dataset:%d" % (len(train_dataset)))
    print("Validation set samples:", len(test_dataset))
    # Data Loader (Input Pipeline)
    train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader   = DataLoader(dataset=test_dataset, batch_size=1,  shuffle=False, num_workers=0, pin_memory=True)
    
    
    #torch.autograd.set_detect_anomaly(True)
    viz = visdom.Visdom(env="proposed-model channel6")
    if not viz.check_connection():
        print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
    # model
    print("\nbuilding models_baseline ...")
    model = AWAN(3, 31, 200, 8)
    #model = SAN(3,128,31,6,3)
    #model = BackBone(3,31,128,6)
    #model = BackBone(3,31,128,5,8)
    #model = FMNet(bNum=3, nblocks=5, input_channels=31, num_features=64, out_channels=31)
    #model = SpatialSpectralSRNet(in_channels=3, out_channels=31, n_channels=64, n_blocks=7, kernel_size=3, upscale_factor=2)
    print('Parameters number is ', sum(param.numel() for param in model.parameters()))
    criterion_train = LossTrainCSS()
    criterion_train_L1 = torch.nn.L1Loss().cuda()

    criterion_valid = Loss_valid()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # batchsize integer times
    if torch.cuda.is_available():
        model.cuda()
        criterion_train.cuda()
        criterion_valid.cuda()                                     

    # Parameters, Loss and Optimizer
    start_epoch = 0
    iteration = 0
    record_val_loss = 1000
    #optimizer = optim.Adam(model.parameters(), lr=opt.init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer = optim.Adam(model.parameters(), lr = opt.init_lr, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    #lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=opt.max_iter, end_learning_rate=0.0000001, power=1.5)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.milestones, opt.gamma)
    # visualzation
    if not os.path.exists(opt.outf):
        os.makedirs(opt.outf)
    loss_csv = open(os.path.join(opt.outf, 'loss.csv'), 'a+')
    log_dir = os.path.join(opt.outf, 'train.log')
    logger = initialize_logger(log_dir)

    # Resume
    resume_file = opt.outf + '/best_net_7epoch.pth'
    #resume_file = ''
    if resume_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    # start epoch
    for epoch in range(start_epoch, opt.end_epoch):
        start_time = time.time()
        total_loss, train_loss, rgb_train_loss,iteration = train(train_loader, model, criterion_train,criterion_train_L1, optimizer, epoch, lr_scheduler, opt)

        lr_scheduler.step()
        val_loss = test(model, val_loader, criterion_valid)
        # Save the model when it is the best so far (or within 1e-4 of the best), and also every 5 epochs
        if torch.abs(val_loss - record_val_loss) < 0.0001 or val_loss < record_val_loss:
            save_checkpoint(opt.outf, epoch, iteration, model, optimizer,best=True)
            if val_loss < record_val_loss:
                record_val_loss = val_loss
        if epoch %5==0:
            save_checkpoint(opt.outf, epoch, iteration, model, optimizer,best=False)
        # print loss
        end_time = time.time()
        epoch_time = end_time - start_time
        print("Epoch [%02d], Iter[%06d], Time:%.9f, Train Loss: %.9f Test Loss: %.9f "
              % (epoch, iteration, epoch_time, train_loss, val_loss))
        
        
        #for learning rate
        viz.line([optimizer.param_groups[0]['lr']*(10**4)],[epoch],win='Learning rate schedule',update='append',
          opts=dict(title='Learning rate schedule',
                          legend=['lr*(10^4)'])) 
        #for HSI train loss
        viz.line([train_loss.detach().cpu()],[epoch],win='HSI Train_loss', 
                  update='append',opts=dict(title=' HSI Train Learning Curve.',
                                            legend=['HSI Train Loss']))
        #for validation loss
        viz.line([val_loss.detach().cpu()],[epoch],win='Val Train_loss', 
                  update='append',opts=dict(title='val loss Learning Curve.',
                                            legend=['val loss']))
        #for rgb train loss
        viz.line([rgb_train_loss.detach().cpu()],[epoch],win=' RGB Train_loss', 
                  update='append',opts=dict(title='rgb train loss Learning Curve.',
                                            legend=['rgb train loss']))
        
        #for total train loss
        viz.line([total_loss.detach().cpu()],[epoch],win='Total Train_loss', 
                  update='append',opts=dict(title='total train loss Learning Curve.',
                                            legend=['total train loss']))
        
        # for train_loss and validation_loss
        viz.line([[train_loss.detach().cpu(),val_loss.detach().cpu()]],[epoch],win='Train_loss and val_loss', 
                  update='append',opts=dict(title='Learning Curve.',
                                            legend=['Train Loss', 'Validation Loss']))
        
        # save loss
        record_loss(loss_csv,epoch, train_loss, val_loss)
        logger.info("Epoch [%02d], Train Loss: %.9f Test Loss: %.9f "
                    % (epoch, train_loss, val_loss))



if __name__ == '__main__':
    main()
    print(torch.__version__)

The train.py file holds the train and test functions:

import torch
from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, LossTrainCSS, Loss_valid
import datetime
import os
import time
import random

log_dir2 = os.path.join("CleanResults", 'error.log')
logger2 = initialize_logger(log_dir2)

def train(train_loader, model, criterion_train,criterion_train_L1, optimizer, epoch, lr_scheduler, opt):

    
    total_loss =  AverageMeter()
    losses = AverageMeter()
    losses_rgb = AverageMeter()
    #random.shuffle(train_loader)
    prev_time = time.time()
    model.train()
    for i,data  in enumerate(train_loader):
      #with torch.autograd.set_detect_anomaly(True):  
        images, labels = data
        # Flag batches whose labels or images fall outside the expected [0, 1] range
        # (NaN values are not caught by these comparisons)
        if labels.min()<0 or labels.max()>1:
            print("yes there is problem in labels ")
            logger2.info(" IMAGBTRain Epoch [%02d],batch no: %d/%d"
            % (epoch,i+1,len(train_loader)))
        if images.min()<0 or images.max()>1:
            print("yes there is problem in images ")
            logger2.info("IMAGATRAIN Epoch [%02d],batch no: %d/%d"
            % (epoch,i+1,len(train_loader)))
        
        images, labels = images.cuda(), labels.cuda()
            

        model.zero_grad()
        optimizer.zero_grad()

        # #lr_scheduler.step()
        fake_hyper  = model.forward(images)
        #loss = criterion_train_L1(fake_hyper, labels)
        loss , loss_rgb = criterion_train(fake_hyper, labels, images)
          
        loss_all = loss + opt.trade_off * loss_rgb
        loss_all.backward()
        optimizer.step()
        # # Determine approximate time left
        iters_done = epoch *len(train_loader) + i
        iters_left =opt.end_epoch*len(train_loader) - iters_done
        time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
        prev_time = time.time()
        #lr_scheduler.step()
        print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],[Time_left=%s]'
                       % (epoch, i+1, len(train_loader), iters_done, time_left))
        ##  record loss
        losses.update(loss.data)
        losses_rgb.update(loss_rgb.data)
        total_loss.update(loss_all.data)
        print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],[Time_left=%s],[train_losses.avg=%.9f], [rgb_train_losses.avg=%.9f]'
                  % (epoch, i+1, len(train_loader), iters_done, time_left,losses.avg, losses_rgb.avg))
    return total_loss.avg, losses.avg,losses_rgb.avg  ,iters_done  

def test(model, test_dataset, criterion):
    
    model.eval()
    losses = AverageMeter()
    for i, data in enumerate(test_dataset):
        images,labels = data
        if labels.min()<0 or labels.max()>1:
            print("yes there is problem in labels ")
            logger2.info(" IMAGBVAL")
        if images.min()<0 or images.max()>1:
            print("yes there is problem in images ")
            logger2.info("IMAGAVAL")
        
        
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
           fake_hyper = model.forward(images)
           loss = criterion(fake_hyper, labels)
           losses.update(loss.data)
    return losses.avg 

# Learning rate
def poly_lr_scheduler(optimizer, init_lr, iteraion, lr_decay_iter=1, max_iter=100, power=0.9):
    """Polynomial decay of learning rate
        :param init_lr is the base learning rate
        :param iteraion is the current iteration
        :param lr_decay_iter how frequently decay occurs, default is 1
        :param max_iter is the maximum number of iterations
        :param power is the polynomial power

    """
    if iteraion % lr_decay_iter or iteraion > max_iter:
        # No decay on this iteration: report the learning rate currently in use
        return optimizer.param_groups[0]['lr']

    lr = init_lr*(1 - iteraion/max_iter)**power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

The utils.py file:

from __future__ import division

import torch
import torch.nn as nn
import logging
import numpy as np
import os
import hdf5storage


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def initialize_logger(file_dir):
    logger = logging.getLogger()
    fhandler = logging.FileHandler(filename=file_dir, mode='a')
    formatter = logging.Formatter('%(asctime)s - %(message)s',"%Y-%m-%d %H:%M:%S")
    fhandler.setFormatter(formatter)
    logger.addHandler(fhandler)
    logger.setLevel(logging.INFO)
    return logger


def save_checkpoint(model_path, epoch, iteration, model, optimizer,best=True):
    state = {
            'epoch': epoch,
            'iter': iteration,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            }
    if not best:
        torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch))
    else:
        torch.save(state, os.path.join(model_path, 'best_net_%depoch.pth' % epoch))


def save_matv73(mat_name, var_name, var):
    hdf5storage.savemat(mat_name, {var_name: var}, format='7.3', store_python_metadata=True)


def record_loss(loss_csv,epoch, train_loss, test_loss):
    """Append one 'epoch,train_loss,test_loss' row to the open CSV file."""
    loss_csv.write('{},{},{}\n'.format(epoch, train_loss, test_loss))
    loss_csv.flush()
    # Keep the file open: it is written to again on the next epoch.


class Loss_train(nn.Module):
    def __init__(self):
        super(Loss_train, self).__init__()

    def forward(self, outputs, label):
        error = torch.abs(outputs - label) / label
        # error = torch.abs(outputs - label)
        rrmse = torch.mean(error.view(-1))
        return rrmse


class Loss_valid(nn.Module):
    def __init__(self):
        super(Loss_valid, self).__init__()

    def forward(self, outputs, label):
        outputs1 = outputs.clone()
        label1 = label.clone()
        error = torch.abs(outputs1 - label1) / label1
        # error = torch.abs(outputs - label)
        rrmse = torch.mean(error.view(-1))
        return rrmse


class LossTrainCSS(nn.Module):
    def __init__(self):
        super(LossTrainCSS, self).__init__()
        self.model_hs2rgb = nn.Conv2d(31, 3, 1, bias=False)
        filtersPath = './cie_1964_w_gain.npz'
        cie_matrix = np.load(filtersPath)['filters']
        cie_matrix = torch.from_numpy(np.transpose(cie_matrix, [1, 0])).unsqueeze(-1).unsqueeze(-1).float()
        self.model_hs2rgb.weight.data = cie_matrix

    def forward(self, outputs, label, rgb_label):
        rrmse = self.mrae_loss(outputs, label)
        # hs2rgb
        with torch.no_grad():
            rgb_tensor = self.model_hs2rgb(outputs)
            rgb_tensor = rgb_tensor / 255
            rgb_tensor = torch.clamp(rgb_tensor, 0, 1) * 255
            # Quantize to 8-bit values and convert back to float.
            # Original line: rgb_tensor = torch.tensor(rgb_tensor, dtype=torch.uint8)
            # It is replaced by the clone().detach() form that torch itself recommends.
            rgb_tensor = rgb_tensor.clone().detach().byte().float()
            rgb_tensor = rgb_tensor / 255
        rrmse_rgb = self.rgb_mrae_loss(rgb_tensor, rgb_label)
        return rrmse, rrmse_rgb

    def mrae_loss(self, outputs, label):

        error = torch.abs(outputs - label) / label
        mrae = torch.mean(error.view(-1))
        return mrae

    def rgb_mrae_loss(self, outputs, label):
        outputs1 = outputs.clone()
        label1 = label.clone()
        error = torch.abs(outputs1 - label1)
        mrae = torch.mean(error.view(-1))
        return mrae
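
One thing that may be worth ruling out with the MRAE losses above: they divide by the label, so a ground-truth pixel that is exactly 0 produces inf (or nan when the prediction is also 0), and that can look like the labels themselves went bad. The following is only a minimal sketch of the effect. `safe_mrae_loss` and its `eps` value are hypothetical and not part of the original code, and whether adding a small constant to the denominator is acceptable depends on the data.

```python
import torch


def mrae_loss(outputs, label):
    # Same formula as LossTrainCSS.mrae_loss: mean(|outputs - label| / label)
    return torch.mean(torch.abs(outputs - label) / label)


def safe_mrae_loss(outputs, label, eps=1e-6):
    # Hypothetical variant: eps keeps the denominator away from zero.
    return torch.mean(torch.abs(outputs - label) / (label + eps))


# Toy tensors: the first label value is exactly 0.
label = torch.tensor([0.0, 0.5, 1.0])
outputs = torch.tensor([0.1, 0.4, 0.9])

plain = mrae_loss(outputs, label)         # 0.1 / 0 -> inf, so the mean is inf
guarded = safe_mrae_loss(outputs, label)  # finite, but large for the zero pixel

print("plain loss finite:  ", torch.isfinite(plain).item())
print("guarded loss finite:", torch.isfinite(guarded).item())
```

Counting how many label pixels are exactly 0 in the dataset (for example with `(labels == 0).sum()`) would show whether this case can actually occur here.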

If you need it, here is the script I use to check that the dataset generated after cropping has not changed:

import h5py
import numpy as np
import torch

h5f = h5py.File('./Dataset/train_clean.h5', 'r')
for i, key in enumerate(h5f.keys()):
    data = np.array(h5f[key])
    hyper = data[0:31, :, :]  # the 31 channels that are checked below
    if i % 1000 == 0:
        print(i)
    if hyper.min() < 0 or hyper.max() > 1:
        print("yes there are dude")
        print(key, "with min: {0} and max:{1}".format(hyper.min(), hyper.max()))

    if torch.isnan(torch.tensor(hyper)).sum() > 0:
        print("yes damn nan")
    if torch.isinf(torch.tensor(hyper)).sum() > 0:
        print("yes damn inf")

Note that when I do not move the labels to CUDA, nothing happens. If I do move them, the labels change, even if I remove the loss and the backward step. The GPU I am using is an RTX 2080 Ti and the torch version is '1.10.0'.
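
To narrow this down, the transfer itself can be checked in isolation. The sketch below keeps a CPU copy of a batch, moves the batch to the device, copies it back, and compares the two bit for bit. `transfer_and_check` is a hypothetical helper, not part of the scripts above, and the random tensor only stands in for a real batch of labels. It falls back to the CPU when no GPU is available.

```python
import torch


def transfer_and_check(labels, device):
    # Keep an independent CPU copy before the transfer.
    reference = labels.clone()
    moved = labels.to(device)
    if moved.is_cuda:
        # Make sure the copy has finished before reading the data back.
        torch.cuda.synchronize()
    # torch.equal is an exact comparison; it also returns False if the
    # tensors contain NaN, since NaN never compares equal to itself.
    unchanged = torch.equal(moved.cpu(), reference)
    return moved, unchanged


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = torch.rand(2, 31, 8, 8)  # stand-in for one batch of labels
moved, unchanged = transfer_and_check(labels, device)
print("labels unchanged after transfer:", unchanged)
```

Running the same check inside the training loop, once right after `labels.cuda()` and once after the forward pass, would show at which step the values start to differ.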

The code for the model is in a file called AWAN.py. You can replace it with any other model; all of them yield the same behavior.

import torch
from torch import nn
from torch.nn import functional as F


class AWCA(nn.Module):
    def __init__(self, channel, reduction=16):
        super(AWCA, self).__init__()
        self.conv = nn.Conv2d(channel, 1, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, h, w = x.size()
        input_x = x
        input_x = input_x.view(b, c, h*w).unsqueeze(1)

        mask = self.conv(x).view(b, 1, h*w)
        mask = self.softmax(mask).unsqueeze(-1)
        y = torch.matmul(input_x, mask).view(b, c)

        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class NONLocalBlock2D(nn.Module):
    def __init__(self, in_channels, reduction=8, dimension=2, sub_sample=False, bn_layer=False):
        super(NONLocalBlock2D, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = self.in_channels // reduction

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0, bias=False)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0, bias=False),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0, bias=False)
            nn.init.constant_(self.W.weight, 0)
            # nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0, bias=False)
        # self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
        #                    kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # f = torch.matmul(theta_x, phi_x)
        f = self.count_cov_second(theta_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def count_cov_second(self, input):
        x = input
        batchSize, dim, M = x.data.shape
        x_mean_band = x.mean(2).view(batchSize, dim, 1).expand(batchSize, dim, M)
        y = (x - x_mean_band).bmm(x.transpose(1, 2)) / M
        return y


class PSNL(nn.Module):
    def __init__(self, channels):
        super(PSNL, self).__init__()
        # nonlocal module
        self.non_local = NONLocalBlock2D(channels)

    def forward(self,x):
        # divide feature map into 4 part
        batch_size, C, H, W = x.shape
        H1 = int(H / 2)
        W1 = int(W / 2)
        nonlocal_feat = torch.zeros_like(x)

        feat_sub_lu = x[:, :, :H1, :W1]
        feat_sub_ld = x[:, :, H1:, :W1]
        feat_sub_ru = x[:, :, :H1, W1:]
        feat_sub_rd = x[:, :, H1:, W1:]

        nonlocal_lu = self.non_local(feat_sub_lu)
        nonlocal_ld = self.non_local(feat_sub_ld)
        nonlocal_ru = self.non_local(feat_sub_ru)
        nonlocal_rd = self.non_local(feat_sub_rd)
        nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu
        nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld
        nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru
        nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd

        return nonlocal_feat


class Conv3x3(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1):
        super(Conv3x3, self).__init__()
        reflect_padding = int(dilation * (kernel_size - 1) / 2)
        self.reflection_pad = nn.ReflectionPad2d(reflect_padding)
        self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class DRAB(nn.Module):
    def __init__(self, in_dim, out_dim, res_dim, k1_size=3, k2_size=1, dilation=1):
        super(DRAB, self).__init__()
        self.conv1 = Conv3x3(in_dim, in_dim, 3, 1)
        self.relu1 = nn.PReLU()
        self.conv2 = Conv3x3(in_dim, in_dim, 3, 1)
        self.relu2 = nn.PReLU()
        # T^{l}_{1}: (conv.)
        self.up_conv = Conv3x3(in_dim, res_dim, kernel_size=k1_size, stride=1, dilation=dilation)
        self.up_relu = nn.PReLU()
        self.se = AWCA(res_dim)
        # T^{l}_{2}: (conv.)
        self.down_conv = Conv3x3(res_dim, out_dim, kernel_size=k2_size, stride=1)
        self.down_relu = nn.PReLU()

    def forward(self, x, res):
        x_r = x
        out = self.relu1(self.conv1(x))
        out = self.conv2(out)
        out = out + x_r
        out = self.relu2(out)
        # T^{l}_{1}
        out = self.up_conv(out)
        out = out + res
        out = self.up_relu(out)
        res = out
        out = self.se(out)
        # T^{l}_{2}
        out = self.down_conv(out)
        out = out + x_r
        out = self.down_relu(out)
        return out, res


class AWAN(nn.Module):
    def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8):
        super(AWAN, self).__init__()
        # 2D Nets
        self.input_conv2D = Conv3x3(inplanes, channels, 3, 1)
        self.input_prelu2D = nn.PReLU()
        self.head_conv2D = Conv3x3(channels, channels, 3, 1)

        self.backbone = nn.ModuleList(
            [DRAB(in_dim=channels, out_dim=channels, res_dim=channels, k1_size=5, k2_size=3, dilation=1) for _ in
             range(n_DRBs)])

        self.tail_conv2D = Conv3x3(channels, channels, 3, 1)
        self.output_prelu2D = nn.PReLU()
        self.output_conv2D = Conv3x3(channels, planes, 3, 1)
        self.tail_nonlocal = PSNL(planes)

    def forward(self, x):
        out = self.DRN2D(x)
        return out

    def DRN2D(self, x):
        out = self.input_prelu2D(self.input_conv2D(x))
        out = self.head_conv2D(out)
        residual = out
        res = out

        for i, block in enumerate(self.backbone):
            out, res = block(out, res)

        out = self.tail_conv2D(out)
        out = torch.add(out, residual)
        out = self.output_conv2D(self.output_prelu2D(out))
        out = self.tail_nonlocal(out)
        return out


if __name__ == "__main__":
    # import os
    # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    input_tensor = torch.rand(1, 3, 64, 64)
    model = AWAN(3, 31, 200, 10)
    # model = nn.DataParallel(model).cuda()
    with torch.no_grad():
        output_tensor = model(input_tensor)
    print(output_tensor.size())
    print('Parameters number is ', sum(param.numel() for param in model.parameters()))
    print(torch.__version__)

Sorry for the long scripts and thanks a lot.