I am working on a spectral super-resolution problem. The model takes an RGB image with 3 channels as input, and a hyperspectral image with 31 channels serves as the label that the output is compared against.
During training, the pixel values of the labels change for no apparent reason.
Any help with this issue would be appreciated.
Could you explain your use case a bit more and especially how the targets are used?
I would assume they are only passed to the loss function so I wouldn’t know how they can be changed.
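One way to test that assumption is to snapshot the labels before a training step and compare afterwards. The following is only a minimal sketch: the tiny `nn.Conv2d` model, the SGD optimizer, and the random tensors are hypothetical stand-ins, not the model or data from this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny 3 -> 31 channel model and random data.
model = nn.Conv2d(3, 31, kernel_size=3, padding=1)
criterion = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

images = torch.rand(2, 3, 8, 8)
labels = torch.rand(2, 31, 8, 8)
labels_before = labels.clone()  # snapshot taken before the training step

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()

# The loss only reads the target, so the snapshot should still match.
labels_unchanged = torch.equal(labels, labels_before)
print("labels unchanged:", labels_unchanged)
```

If the same comparison fails in the real training loop, the line between the snapshot and the comparison that touches the labels is the place to look.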
Thanks a lot for your concern.
I am using a model to reconstruct a hyperspectral image from an RGB image. The input to the model is an RGB image of shape N x 3 x W x H, and the output is a hyperspectral image of shape N x 31 x W x H (C = 31). The labels are ground-truth hyperspectral images of the same shape, N x 31 x W x H. I only use the labels in two places: when fetching the data and moving it to the GPU, and when computing the L1 loss between the reconstructed and ground-truth images. The problem I am facing is that the labels' pixel values change during training until they contain NaN values. I think some of the images may be corrupted somehow, and I am checking that now. I need to know the following:
1- Is it possible for the labels to be changed in any way?
2- The training is stable as long as I do not switch the model mode with model.train() or model.eval().
3- The training is fine as long as I do not split the code into a train function and a test function, which is quite weird. Sorry for the long post.
Thanks again a lot for your concern.
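For the corruption check mentioned above, a scan over the stored samples could look like the sketch below. It is written with plain NumPy arrays as an assumption; `find_bad_samples` is a hypothetical helper and the three small arrays only stand in for real patches. Note that comparisons against NaN are always False, so a pure range check (min < 0 or max > 1) can miss NaN values, which is why `np.isfinite` is tested first.

```python
import numpy as np

def find_bad_samples(samples, low=0.0, high=1.0):
    """Return indices of arrays with non-finite or out-of-range values."""
    bad = []
    for idx, arr in enumerate(samples):
        arr = np.asarray(arr)
        # isfinite catches NaN and +/-inf, which a range check alone can miss.
        if not np.isfinite(arr).all() or arr.min() < low or arr.max() > high:
            bad.append(idx)
    return bad

rng = np.random.default_rng(0)
clean = rng.random((31, 4, 4), dtype=np.float32)  # values in [0, 1)

with_nan = clean.copy()
with_nan[0, 0, 0] = np.nan

too_large = clean.copy()
too_large[1, 2, 3] = 1.5

bad_indices = find_bad_samples([clean, with_nan, too_large])
print(bad_indices)
```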
- Yes, since the labels are a tensor you can of course change a tensor. However, since they should be used in the loss calculation only, I wouldn’t know where they are changed without seeing code.
- I don’t see a question here. If the training is “unstable” when eval() is called, I assume you are seeing a higher loss value which you wouldn’t expect? If so, I would guess that some norm layers (e.g. batchnorm layers) are updating their internal stats with noisy batch stats (we have a lot of similar discussions about these issues).
- I also don’t see a question here, so you would need to describe the issue more.
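To illustrate the point about norm layers: a batchnorm layer updates its running statistics on every forward pass in train mode and leaves them frozen in eval mode. The snippet below is a small sketch with a standalone `nn.BatchNorm2d` and random input (both are assumptions for illustration), and it only matters if the model actually contains such layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 16, 16) * 3.0 + 5.0  # batch with non-trivial mean/var

bn.train()
mean_before = bn.running_mean.clone()
bn(x)  # a forward pass in train mode updates the running statistics
train_mode_updates = not torch.equal(bn.running_mean, mean_before)

bn.eval()
mean_frozen = bn.running_mean.clone()
bn(x)  # a forward pass in eval mode uses the stats but does not update them
eval_mode_updates = not torch.equal(bn.running_mean, mean_frozen)

print(train_mode_updates, eval_mode_updates)
```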
Thanks a lot for your reply. I ran a very interesting set of experiments (all of them without any loss calculation or backpropagation):
1- I divided my training set into four equal parts and trained on each part individually to see whether the training is stable and whether the labels change. In this case the training loop is stable.
2- I used half of my training set with the same settings as above, and the training loop is stable.
3- If I increase the training set to more than half, the labels change even though there is no loss function or backpropagation.
4- If I make the model very simple, the training is stable.
5- If the labels are not moved to CUDA, the training is stable.
I came to the conclusion that the model and the data (labels) occupy more memory than the GPU can handle, so there is a memory leak or something similar, and that it affects the labels rather than the RGB images because the labels are very large (well over 25 GB). Is that conclusion right? If you need me to post any code, just let me know.
Thanks a lot for your help and patience.
No, I don’t think the conclusion is right, as a memory leak wouldn’t show up as a memory corruption.
Yes, a minimal, executable code snippet showing this behavior would be great to debug this issue.
Thanks for the reply.
The code for the simple model is as follows:
import torch.nn as nn

class Residual_Block(nn.Module):
    def __init__(self, Cn=64, ksize=3):
        super(Residual_Block, self).__init__()
        self.conv = self.make_layer(Conv_ReLU_Block, conv_num=1, cn=Cn)
        self.ouput = nn.Conv2d(in_channels=Cn, out_channels=Cn, kernel_size=ksize, stride=1, padding=int((ksize - 1) / 2), bias=False)
        self.relu = nn.PReLU()

    def make_layer(self, block, conv_num, cn):
        layer = []
        for _ in range(conv_num):
            layer.append(block(cn))
        return nn.Sequential(*layer)

    def forward(self, x):
        out = self.conv(x)
        out = self.ouput(out)
        out = out + x
        # out = self.relu(out)
        return out

class FMNet2(nn.Module):
    def __init__(self,in_channels=3,channels=128,out_channels=31):
        super(FMNet2, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, channels, kernel_size=3,padding=1,stride=1, bias=False)
        self.res1 = Residual_Block(Cn=channels)
        self.res2 = Residual_Block(Cn=channels)
        self.res3 = Residual_Block(Cn=channels)
        self.res4 = Residual_Block(Cn=channels)
        self.res5 = Residual_Block(Cn=channels)
        self.res6 = Residual_Block(Cn=channels)
        self.res7 = Residual_Block(Cn=channels)
        self.res8 = Residual_Block(Cn=channels)
        self.res9 = Residual_Block(Cn=channels)
        self.conv1 = nn.Conv2d(channels, out_channels, kernel_size=3,padding=1,stride=1, bias=False)

    def forward(self,x):
        print(x.shape)
        out = self.conv0(x)
        print(out.shape)
        res = out
        out = self.res1(out)
        out = self.res2(out)
        out = self.res3(out)
        out = self.res4(out)
        out = self.res5(out)
        out = self.res6(out)
        out = self.res7(out)
        out = self.res8(out)
        out = self.res9(out)
        out = out + res
        out = self.conv1(out)
        return out
The code for the dataset class is as follows:
import random
import h5py
import numpy as np
import torch
import torch.utils.data as udata

class HyperDataset(udata.Dataset):
    def __init__(self, mode='train'):
        self.mode = mode
        if self.mode == 'train':
            self.h5f = h5py.File('./Dataset/train_clean.h5', 'r')
        elif self.mode == 'test':
            self.h5f = h5py.File('./Dataset/test_final.h5', 'r')
        #self.keys = list(self.h5f.keys())
        if 'train' in self.mode:
            self.keys = list(self.h5f.keys())
            random.shuffle(self.keys)
            self.len = len(self.keys)
        else:
            self.keys = list(self.h5f.keys())
            self.keys.sort()
            self.len = len(self.keys)

    def __len__(self):
        #return len(self.keys)
        return self.len

    def __getitem__(self, index):
        key = str(self.keys[index])
        data = np.array(self.h5f[key])
        data = torch.Tensor(data)
        # the first part of the tuple is the RGB image and the second part is the hyperspectral image
        return data[31:34,:,:], data[0:31,:,:]

    def close(self):
        self.h5f.close()

    def shuffle(self):
        if 'train' in self.mode:
            random.shuffle(self.keys)
Note that the data is cropped and saved in one big h5 file offline before training.
The simple training loop:
train_dataset = dataset.HyperDataset(mode='train')
test_dataset = dataset.HyperDataset(mode='test')
print("Train_dataset:%d" % (len(train_dataset)))
print("Validation set samples:", len(test_dataset))
# Data Loader (Input Pipeline)
dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
val_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
# ----------------------------------------
# Training
# ----------------------------------------
# Count start time
prev_time = time.time()
logger2 = initialize_logger(log_dir2)
# For loop training
for epoch in range(opt.epochs):
    generator.train()
    total_loss = utils.AverageMeter()
    for i, (img_A, img_B) in enumerate(dataloader):
        # this control statement is only to check if the data is changing or not
        # it is saved to the log file
        if img_B.min()<0 or img_B.max()>1:
            print("yes there is problem in labels ")
            logger2.info(" IMAGBTRain Epoch [%02d],batch no: %d/%d"
                         % (epoch,i+1,len(dataloader)))
        if img_A.min()<0 or img_A.max()>1:
            print("yes there is problem in labels ")
            logger2.info("IMAGATRAIN Epoch [%02d],batch no: %d/%d"
                         % (epoch,i+1,len(dataloader)))
        generator.zero_grad()
        optimizer_G.zero_grad()
        # To device
        img_A = img_A.cuda()
        img_B = img_B.cuda()
        # Train Generator
        # Forward propagation
        recon_B = generator(img_A)
        # Losses
        loss = criterion_L1(recon_B, img_B)
        # Overall Loss and optimize
        loss.backward()
        optimizer_G.step()
        # Determine approximate time left
        iters_done = epoch * len(dataloader) + i
        iters_left = opt.epochs * len(dataloader) - iters_done
        time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
        prev_time = time.time()
        total_loss.update(loss.data)
        # Print log
        print("\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] Time_left: %s" %
              ((epoch + 1), opt.epochs, i, len(dataloader), total_loss.avg, time_left))
        # Save model at certain epochs or iterations
        save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)
        # Learning rate decrease at certain epochs
        adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
    #--------------------------------------
    # Validation
    #--------------------------------------
    generator.eval()
    losses = utils.AverageMeter()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            images,labels = data
            # this control statement is only to check if the data is changing or not
            # it is saved to the log file
            if labels.min()<0 or labels.max()>1:
                print("yes there is problem in labels ")
                logger2.info(" IMAGBVAL Epoch [%02d],batch no: %d/%d"
                             % (epoch,i+1,len(val_loader)))
            if images.min()<0 or images.max()>1:
                print("yes there is problem in labels ")
                logger2.info("IMAGAVAL Epoch [%02d],batch no: %d/%d"
                             % (epoch,i+1,len(val_loader)))
            images = images.cuda()
            labels = labels.cuda()
            fake_hyper = generator.forward(images)
            loss_v = criterion_valid(fake_hyper, labels)
            losses.update(loss_v.data)
    print("\r [Total Loss: %.4f]" %
          (losses.avg))
Thanks a lot for your help
The code is unfortunately not executable as some definitions are missing (e.g. Conv_ReLU_Block) as well as the data.
Sorry for the missing definition
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
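Since the training loop passes `loss.data` to `update`, the meter above ends up accumulating tensors rather than numbers. A possible variant, given here only as a sketch (`FloatAverageMeter` is a hypothetical name, not part of the posted code), converts each value to a Python float first, so that only plain numbers are kept between iterations:

```python
class FloatAverageMeter:
    """Running average over plain Python floats (hypothetical variant)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # float() also accepts a single-element tensor such as loss.detach()
        val = float(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = FloatAverageMeter()
for batch_loss in (0.5, 0.25, 0.75):
    meter.update(batch_loss)
print(meter.avg)
```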
and the Conv_ReLU_Block:
class Conv_ReLU_Block(nn.Module):
    def __init__(self, nFeat=64, ksize=3):
        super(Conv_ReLU_Block, self).__init__()
        self.conv = nn.Conv2d(in_channels=nFeat, out_channels=nFeat, kernel_size=ksize, stride=1, padding=int((ksize - 1) / 2), bias=False)
        self.relu = nn.PReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
and the losses:
# Loss functions
criterion_L1 = torch.nn.L1Loss().cuda()
criterion_valid = torch.nn.L1Loss().cuda()
Regarding the data, if you mean the dataset itself, it is quite large; here is the link for it:
https://competitions.codalab.org/competitions/22225
and the offline code for creating the dataset:
import os
import os.path
import h5py
import cv2
import glob
import numpy as np
import argparse
import hdf5storage
import random
from scipy.io import loadmat
parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../NTIRE2020', help="data path")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
opt = parser.parse_args()
def main():
    if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
    h5f = h5py.File('./Dataset/train_clean.h5', 'w')
    process_data(h5f, patch_size=opt.patch_size, stride=opt.stride, mode='train')

def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def process_data(h5f,patch_size, stride, mode):
    if mode == 'train':
        print("\nprocess training set ...\n")
        patch_num = 1
        filenames_hyper =glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        print(len(filenames_rgb),len(filenames_hyper))
        print("\nbefore loop ...\n")
        #for k in range(1): # make small dataset
        for k in range(len(filenames_hyper)):
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            #mat = h5py.File(filenames_hyper[k], 'r')
            mat = loadmat(filenames_hyper[k])
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k]) # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate training sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]
                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
                patch_num += 1
        print("\ntraining set: # samples %d\n" % (patch_num-1))

if __name__ == '__main__':
    main()
Thanks a lot for your help.
If you download this dataset from the clean track, make sure you remove image number 340, since it has a different distribution from the other images in the dataset.
If you need me to post the full, organized code, let me know. I am really grateful for your help. Thanks a lot.
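As a rough size check for the patch extraction in Im2Patch, the number of patches per image and the bytes per float32 patch can be estimated as below. This is a sketch under assumptions: the 482 x 512 image size is assumed here rather than read from the files, and both helper functions are hypothetical.

```python
def patches_per_image(height, width, patch_size=64, stride=32):
    """Count the sliding-window positions that Im2Patch-style slicing yields."""
    rows = (height - patch_size) // stride + 1
    cols = (width - patch_size) // stride + 1
    return rows * cols

def patch_bytes(channels, patch_size=64, bytes_per_value=4):
    """Bytes needed for one float32 patch with the given channel count."""
    return channels * patch_size * patch_size * bytes_per_value

# Assumed image size: 482 x 512 pixels.
n_patches = patches_per_image(482, 512)
label_bytes = patch_bytes(31)   # hyperspectral label patch
sample_bytes = patch_bytes(34)  # label plus RGB, as stored in the h5 file
print(n_patches, label_bytes, sample_bytes)
```

Multiplying `n_patches` by the number of images and by `sample_bytes` gives an estimate of the h5 file size, which can be compared with the file on disk as a sanity check.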
Thanks for the link. It seems I might need to register and download the large dataset.
Could you remove the dataset dependency and try to come up with a minimal, executable code snippet which would reproduce the issue?
I doubt the original dataset is needed and would assume that random tensors using the same shape could reproduce the issue (or just a single original sample).
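A stand-in dataset along those lines could look like the sketch below. `RandomHyperDataset` is a hypothetical class; it only mirrors the channel layout of the posted HyperDataset (channels 0..30 as the hyperspectral label, 31..33 as the RGB input) and uses random values instead of the h5 file.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomHyperDataset(Dataset):
    """Hypothetical stand-in: random 34-channel patches split into RGB and HSI."""

    def __init__(self, num_samples=8, patch_size=64, seed=0):
        gen = torch.Generator().manual_seed(seed)
        # Channels 0..30 play the hyperspectral label, 31..33 the RGB input.
        self.data = torch.rand(num_samples, 34, patch_size, patch_size, generator=gen)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        sample = self.data[index]
        return sample[31:34], sample[0:31]

loader = DataLoader(RandomHyperDataset(), batch_size=4, shuffle=False)
rgb, hsi = next(iter(loader))
print(rgb.shape, hsi.shape)
```

Swapping this in for the real dataset keeps the shapes identical, so if the labels still change, the data files can be ruled out.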
Thanks for your reply. I wrote a script to generate random hyperspectral data and the corresponding RGB images (run this first to generate the dataset in its folder). The script also generates the cie_1964_w_gain.npz file, which is used in the loss function. Copy this file to the location of your main file.
import numpy as np
import hdf5storage ## you may install it using pip
import cv2 as cv
import argparse
import os
from scipy.io import loadmat
from matplotlib import pyplot as plt
from os.path import basename,join,splitext
import torch
import torch.nn as nn
from PIL import Image
## This script creates random hyperspectral images with the same width and height as the
## original dataset and 31 channels: (482,512,31)
img_width = 482
img_height = 512
img_channel = 31
min_value = 0.0001
filtersPath = "./NTIRE2020/cie_1964_w_gain.npz"
BIT_8 = 256
parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--root", type=str, default='./NTIRE2020', help="hyper data path")
parser.add_argument("--data_path_hyper", type=str, default='./NTIRE2020_Train_Spectral', help="hyper data path")
parser.add_argument("--data_path_rgb", type=str, default='./NTIRE2020_Train_clean', help="rgb data path")
opt = parser.parse_args()
train_data_path = "./CleanResults/1"
## the path to save dataset
## create the folder for the hyperspectral data if it does not exist
if not os.path.exists(os.path.join(opt.root, opt.data_path_hyper)):
    os.makedirs(os.path.join(opt.root, opt.data_path_hyper))
## create the folder for the rgb data if it does not exist
if not os.path.exists(os.path.join(opt.root, opt.data_path_rgb)):
    os.makedirs(os.path.join(opt.root, opt.data_path_rgb))

## create camera spectral curves to create RGB images
def create_camera_curves():
    filters = np.array(
[[ 0.41817229, 0.04383285, 1.88213184],
[ 1.85425449 , 0.19147895 , 8.52029389],
[ 4.47484197 , 0.46778508 ,21.28163131],
[ 6.88603366 , 0.84577887 ,33.99399574],
[ 8.39714516 , 1.35751927 ,43.04896615],
[ 8.1119695 , 1.95625181 ,43.65117201],
[ 6.61455659 , 2.80353959 ,38.19302491],
[ 4.28065468 , 4.04979288 ,28.83148095],
[ 1.76171245 , 5.54556362 ,16.89601022],
[ 0.35388741 , 7.41626659 , 9.08678754],
[ 0.08350447 ,10.07641566 , 4.78136574],
[ 0.81983625 ,13.26840209 , 2.45180064],
[ 2.57666885 ,16.6583405 , 1.32846351],
[ 5.17506725 ,19.13938808 , 0.66634341],
[ 8.24479763 ,21.03705468 , 0.2992648 ],
[11.59403605 ,21.68813996 , 0.08726733],
[15.43222205 ,21.81014328 , 0. ],
[19.22736473 ,20.89631022 , 0. ],
[22.19258323 ,19.00212068 , 0. ],
[24.47626429 ,17.00053586 , 0. ],
[24.59596279 ,14.39680704 , 0. ],
[22.54970928 ,11.54566013 , 0. ],
[18.73811079 ,8.70483506 , 0. ],
[14.16834157 ,6.19951365 , 0. ],
[ 9.44386149 ,3.93253499 , 0. ],
[ 5.87176941 ,2.35375213 , 0. ],
[ 3.33860341 ,1.31824378 , 0. ],
[ 1.77820327 ,0.6954211 , 0. ],
[ 0.89392678 ,0.3478177 , 0. ],
[ 0.43636996 ,0.16945318 , 0. ],
[ 0.20956822 ,0.0813007 , 0. ]])
    bands = np.array([[400 ,410, 420, 430, 440, 450 ,460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570,
                       580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700]])
    #np.savez(filtersPath, filters=filters, bands=bands)
    np.savez("./NTIRE2020/cie_1964_w_gain.npz", filters=filters,bands=bands)

create_camera_curves()
def create_rgb(hyper_img, filtersPath):
    model_hs2rgb = nn.Conv2d(31, 3, 1, bias=False)
    cie_matrix = np.load(filtersPath)['filters']
    cie_matrix = torch.from_numpy(np.transpose(cie_matrix, [1, 0])).unsqueeze(-1).unsqueeze(-1).float()
    model_hs2rgb.weight.data = cie_matrix
    with torch.no_grad():
        hyper_tensor = torch.tensor(np.transpose(hyper_img, [2, 0, 1]),dtype=torch.float32)
        C,W,H = hyper_tensor.shape
        rgb_tensor = model_hs2rgb(hyper_tensor.view(1,C,W,H))
        rgb_tensor = rgb_tensor / 255
        rgb_tensor = torch.clamp(rgb_tensor, 0, 1) * 255
        rgb_tensor = rgb_tensor.squeeze(0)
        rgb_img = rgb_tensor.numpy()
        rgb_img = np.transpose(rgb_img,[1,2,0])
    return rgb_img

## the main loop to create the dataset
for i in range(1,450):
    print("creating dataset No.",i)
    hyper_img = np.random.rand(img_width,img_height,img_channel)
    ## this step to make the min value not less than 0.0001 to make the train stable
    hyper_img[hyper_img < min_value] = min_value
    train_data_path_hyper = os.path.join(opt.root, opt.data_path_hyper, 'train'+str(i)+'.mat')
    train_data_path_rgb = os.path.join(opt.root, opt.data_path_rgb)
    hdf5storage.savemat(train_data_path_hyper, {'cube': hyper_img}, format='5')
    # Project image to RGB
    rgbIm = np.true_divide(create_rgb(hyper_img, filtersPath), BIT_8)
    # Save image file
    # save RGB image
    cv.imwrite(os.path.join(train_data_path_rgb ,'train'+str(i)+'.png'), (rgbIm * 255).astype(np.uint8))
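The 1x1 convolution in create_rgb applies the same 31 -> 3 linear map to every pixel, so the projection step can also be written as a plain matrix product. The following NumPy sketch illustrates only that step (without the scaling and clamping); `project_to_rgb` is a hypothetical helper and the filter values are random placeholders, not the CIE curves.

```python
import numpy as np

def project_to_rgb(hyper_img, filters):
    """Project an (H, W, 31) cube to (H, W, 3) with a (31, 3) filter matrix."""
    # For every pixel: rgb[k] = sum over c of hyper[c] * filters[c, k]
    return np.einsum("hwc,ck->hwk", hyper_img, filters)

rng = np.random.default_rng(0)
hyper = rng.random((4, 5, 31))
filters = rng.random((31, 3))  # placeholder values, not the CIE curves

rgb = project_to_rgb(hyper, filters)
# Same result via an explicit per-pixel matrix product.
reference = hyper.reshape(-1, 31) @ filters
print(rgb.shape)
```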
Run the next script first to generate the image patches offline. Adjust the paths to point to the NTIRE2020 folder that holds the data. The file name is train_data_preprocess
import os
import os.path
import h5py
import cv2
import glob
import numpy as np
import argparse
import hdf5storage
import random
from scipy.io import loadmat
parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../../NTIRE2020', help="data path")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
opt = parser.parse_args()
def main():
    if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
    h5f = h5py.File('./Dataset/train_clean.h5', 'w')
    process_data(h5f, patch_size=opt.patch_size, stride=opt.stride, mode='train')

def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def process_data(h5f,patch_size, stride, mode):
    if mode == 'train':
        print("\nprocess training set ...\n")
        patch_num = 1
        filenames_hyper =glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Train_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        print(len(filenames_rgb),len(filenames_hyper))
        print("\nbefore loop ...\n")
        #for k in range(1): # make small dataset
        for k in range(len(filenames_hyper)):
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            #mat = h5py.File(filenames_hyper[k], 'r')
            mat = loadmat(filenames_hyper[k])
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k]) # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate training sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]
                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
                patch_num += 1
        print("\ntraining set: # samples %d\n" % (patch_num-1))

if __name__ == '__main__':
    main()
Copy ten images and create two folders named NTIRE2020_Validation_Spectral and NTIRE2020_Validation_Clean for the validation. The script name is valid_data_preprocess
import os
import os.path
import h5py
from scipy.io import loadmat,savemat
import cv2
import glob
import numpy as np
import argparse
import hdf5storage
parser = argparse.ArgumentParser(description="SpectralSR")
parser.add_argument("--data_path", type=str, default='../../../NTIRE2020', help="data path")
#parser.add_argument("--data_path", type=str, default='./NTIRE2020', help="data path")
parser.add_argument("--patch_size", type=int, default=64, help="data patch size")
parser.add_argument("--stride", type=int, default=32, help="data patch stride")
parser.add_argument("--out_data_path", type=str, default='./Dataset', help="out data path")
opt = parser.parse_args()
def main():
    if not os.path.exists(opt.out_data_path):
        os.makedirs(opt.out_data_path)
    h5f = h5py.File('./Dataset/test_final.h5', 'w')
    process_data(h5f,patch_size=opt.patch_size, stride=opt.stride, mode='valid')

def normalize(data, max_val, min_val):
    return (data-min_val)/(max_val-min_val)

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def process_data(h5f,patch_size, stride, mode):
    if mode == 'valid':
        print("\nprocess valid set ...\n")
        patch_num = 1
        filenames_hyper = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Validation_Spectral', '*.mat'))
        filenames_rgb = glob.glob(os.path.join(opt.data_path, 'NTIRE2020_Validation_Clean', '*.png'))
        filenames_hyper.sort()
        filenames_rgb.sort()
        #for k in range(1): # make small dataset
        for k in range(len(filenames_hyper)):
            # continue
            print([filenames_hyper[k], filenames_rgb[k]])
            # load hyperspectral image
            mat = hdf5storage.loadmat(filenames_hyper[k])
            #mat = h5py.File(filenames_hyper[k], 'r')
            hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [2, 0, 1])
            hyper = normalize(hyper, max_val=1., min_val=0.)
            # load rgb image
            rgb = cv2.imread(filenames_rgb[k]) # imread -> BGR model
            rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
            rgb = np.transpose(rgb, [2, 0, 1])
            rgb = normalize(np.float32(rgb), max_val=255., min_val=0.)
            # create patches
            patches_hyper = Im2Patch(hyper, win=patch_size, stride=stride)
            patches_rgb = Im2Patch(rgb, win=patch_size, stride=stride)
            # add data: regroup the patches
            for j in range(patches_hyper.shape[3]):
                print("generate valid sample #%d" % patch_num)
                sub_hyper = patches_hyper[:, :, :, j]
                sub_rgb = patches_rgb[:, :, :, j]
                data = np.concatenate((sub_hyper, sub_rgb), 0)
                h5f.create_dataset(str(patch_num), data=data)
                patch_num += 1
        print("\nvalid set: # samples %d\n" % (patch_num-1))

if __name__ == '__main__':
    main()
The dataset class interface file, dataset.py, is as follows. A simple hack to make my code run is to not load all of the dataset patches at once: to do this, uncomment the self.len = 32000 line and comment out the line below it.
import os
import random
import h5py
import numpy as np
import torch
import torch.utils.data as udata
class HyperDataset(udata.Dataset):
    def __init__(self, mode='train'):
        self.mode = mode
        if self.mode == 'train':
            self.h5f = h5py.File('./Dataset/train_clean.h5', 'r')
        elif self.mode == 'test':
            self.h5f = h5py.File('./Dataset/test_final.h5', 'r')
        #self.keys = list(self.h5f.keys())
        if 'train' in self.mode:
            self.keys = list(self.h5f.keys())
            random.shuffle(self.keys)
            #self.len = 32000
            self.len = len(self.keys)
        else:
            self.keys = list(self.h5f.keys())
            self.keys.sort()
            self.len = len(self.keys)

    def __len__(self):
        #return len(self.keys)
        return self.len

    def __getitem__(self, index):
        key = str(self.keys[index])
        data = np.array(self.h5f[key])
        data = torch.Tensor(data)
        return data[31:34,:,:], data[0:31,:,:]

    def close(self):
        self.h5f.close()

    def shuffle(self):
        if 'train' in self.mode:
            random.shuffle(self.keys)
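As an alternative to hard-coding self.len = 32000 inside the class, the number of samples can be limited from the outside with `torch.utils.data.Subset`. The sketch below uses a hypothetical `ToyDataset` in place of HyperDataset, purely to show the wrapping; the sizes 100 and 32 are arbitrary.

```python
from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Hypothetical stand-in for HyperDataset; item i is just the integer i."""

    def __init__(self, n):
        self.items = list(range(n))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]

full = ToyDataset(100)
# Keep only the first 32 samples without touching the dataset class itself.
limited = Subset(full, range(32))
print(len(full), len(limited), limited[5])
```

The wrapped object can be passed to a DataLoader in place of the full dataset, which makes it easy to vary the subset size while bisecting the problem.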
The main file is called main.py
import torch
import torch.nn as nn
import argparse
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
#from torch.autograd import Variable
import os
import time
#import random
#from dataset import HyperDatasetValid, HyperDatasetTrain1, HyperDatasetTrain2, HyperDatasetTrain3, HyperDatasetTrain4 # Clean Data set
import dataset # Clean Data set
from AWAN import AWAN
from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, LossTrainCSS, Loss_valid
import visdom
from train import train,test
from torch_poly_lr_decay import PolynomialLRDecay
#from BackBone import BackBone
#from RSCAN import BackBone,SpatialSpectralSRNet
#from FMNet import FMNet
#from collections import OrderedDict
#from proposed import SpatialSpectralSRNet
#from SAN import SAN
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = argparse.ArgumentParser(description="SSR")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--end_epoch", type=int, default=150+1, help="number of epochs")
parser.add_argument("--init_lr", type=float, default=1e-4, help="initial learning rate")
parser.add_argument("--decay_power", type=float, default=1.5, help="decay power")
parser.add_argument("--trade_off", type=float, default=10, help="trade_off")
parser.add_argument("--max_iter", type=float, default=3000000, help="max_iter") # patch48:380x450/32x100-534375; patch96:82x450/32x100-113906
parser.add_argument("--outf", type=str, default="CleanResults", help='path log files')
parser.add_argument('--b1', type = float, default = 0.9, help = 'Adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type = float, default = 0.999, help = 'Adam: decay of second order momentum of gradient')
parser.add_argument('--weight_decay', type = float, default = 0, help = 'weight decay for optimizer')
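# nargs='+' with type=int parses a list of epochs from the command line; the default is unchanged
parser.add_argument("--milestones", type=int, nargs='+', default=list(range(30, 150, 30)), help="how many epoch to reduce the lr")
# gamma is a fraction, so it is parsed as a float
parser.add_argument("--gamma", type=float, default=0.5, help="how much to reduce the lr each time")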
opt = parser.parse_args()
def main():
    cudnn.benchmark = True
    # load dataset
    print("\nloading dataset ...")
    print("\nloading dataset ...")
    train_dataset = dataset.HyperDataset(mode='train')
    test_dataset = dataset.HyperDataset(mode='test')
    print("Train_dataset:%d" % (len(train_dataset)))
    print("Validation set samples:", len(test_dataset))
    # Data Loader (Input Pipeline)
    train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    #torch.autograd.set_detect_anomaly(True)
    viz = visdom.Visdom(env="proposed-model channel6")
    if not viz.check_connection():
        print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
    # model
    print("\nbuilding models_baseline ...")
    model = AWAN(3, 31, 200, 8)
    #model = SAN(3,128,31,6,3)
    #model = BackBone(3,31,128,6)
    #model = BackBone(3,31,128,5,8)
    #model = FMNet(bNum=3, nblocks=5, input_channels=31, num_features=64, out_channels=31)
    #model = SpatialSpectralSRNet(in_channels=3, out_channels=31, n_channels=64, n_blocks=7, kernel_size=3, upscale_factor=2)
    print('Parameters number is ', sum(param.numel() for param in model.parameters()))
    criterion_train = LossTrainCSS()
    criterion_train_L1 = torch.nn.L1Loss().cuda()
    criterion_valid = Loss_valid()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model) # batchsize integer times
    if torch.cuda.is_available():
        model.cuda()
        criterion_train.cuda()
        criterion_valid.cuda()
    # Parameters, Loss and Optimizer
    start_epoch = 0
    iteration = 0
    record_val_loss = 1000
    #optimizer = optim.Adam(model.parameters(), lr=opt.init_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer = optim.Adam(model.parameters(), lr = opt.init_lr, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    #lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=opt.max_iter, end_learning_rate=0.0000001, power=1.5)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.milestones, opt.gamma)
    # visualization
    if not os.path.exists(opt.outf):
        os.makedirs(opt.outf)
    loss_csv = open(os.path.join(opt.outf, 'loss.csv'), 'a+')
    log_dir = os.path.join(opt.outf, 'train.log')
    logger = initialize_logger(log_dir)
    # Resume
    resume_file = opt.outf + '/best_net_7epoch.pth'
    #resume_file = ''
    if resume_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
    # start epoch
    for epoch in range(start_epoch, opt.end_epoch):
        start_time = time.time()
        total_loss, train_loss, rgb_train_loss,iteration = train(train_loader, model, criterion_train,criterion_train_L1, optimizer, epoch, lr_scheduler, opt)
        lr_scheduler.step()
        val_loss = test(model, val_loader, criterion_valid)
        # Save model either the best model so far or every 10 epochs
        if torch.abs(val_loss - record_val_loss) < 0.0001 or val_loss < record_val_loss:
            save_checkpoint(opt.outf, epoch, iteration, model, optimizer,best=True)
            if val_loss < record_val_loss:
                record_val_loss = val_loss
        if epoch %5==0:
            save_checkpoint(opt.outf, epoch, iteration, model, optimizer,best=False)
        # print loss
        end_time = time.time()
        epoch_time = end_time - start_time
        print("Epoch [%02d], Iter[%06d], Time:%.9f, Train Loss: %.9f Test Loss: %.9f "
              % (epoch, iteration, epoch_time, train_loss, val_loss))
        #for learning rate
        viz.line([optimizer.param_groups[0]['lr']*(10**4)],[epoch],win='Learning rate schedule',update='append',
                 opts=dict(title='Learning rate schedule',
                           legend=['lr*(10^4)']))
        #for HSI train loss
        viz.line([train_loss.detach().cpu()],[epoch],win='HSI Train_loss',
                 update='append',opts=dict(title=' HSI Train Learning Curve.',
                                           legend=['HSI Train Loss']))
        #for validation loss
        viz.line([val_loss.detach().cpu()],[epoch],win='Val Train_loss',
                 update='append',opts=dict(title='val loss Learning Curve.',
                                           legend=['val loss']))
        #for rgb train loss
        viz.line([rgb_train_loss.detach().cpu()],[epoch],win=' RGB Train_loss',
                 update='append',opts=dict(title='rgb train loss Learning Curve.',
                                           legend=['rgb train loss']))
        #for total train loss
        viz.line([total_loss.detach().cpu()],[epoch],win='Total Train_loss',
                 update='append',opts=dict(title='total train loss Learning Curve.',
                                           legend=['total train loss']))
        # for train_loss and validation_loss
        viz.line([[train_loss.detach().cpu(),val_loss.detach().cpu()]],[epoch],win='Train_loss and val_loss',
                 update='append',opts=dict(title='Learning Curve.',
                                           legend=['Train Loss', 'Validation Loss']))
        # save loss
        record_loss(loss_csv,epoch, train_loss, val_loss)
        logger.info("Epoch [%02d], Train Loss: %.9f Test Loss: %.9f "
                    % (epoch, train_loss, val_loss))

if __name__ == '__main__':
    main()
    print(torch.__version__)
train.py holds the train and test functions:
import torch
from utils import AverageMeter, initialize_logger, save_checkpoint, record_loss, LossTrainCSS, Loss_valid
import datetime
import os
import time
import random
log_dir2 = os.path.join("CleanResults", 'error.log')
logger2 = initialize_logger(log_dir2)
def train(train_loader, model, criterion_train, criterion_train_L1, optimizer, epoch, lr_scheduler, opt):
    total_loss = AverageMeter()
    losses = AverageMeter()
    losses_rgb = AverageMeter()
    #random.shuffle(train_loader)
    prev_time = time.time()
    model.train()
    for i, data in enumerate(train_loader):
        #with torch.autograd.set_detect_anomaly(True):
        images, labels = data
        # flag labels with values outside the [0, 1] range
        if labels.min() < 0 or labels.max() > 1:
            print("yes there is problem in labels ")
            logger2.info(" IMAGBTRain Epoch [%02d],batch no: %d/%d"
                         % (epoch, i+1, len(train_loader)))
        # flag input images with values outside the [0, 1] range
        if images.min() < 0 or images.max() > 1:
            print("yes there is problem in images ")
            logger2.info("IMAGATRAIN Epoch [%02d],batch no: %d/%d"
                         % (epoch, i+1, len(train_loader)))
        images, labels = images.cuda(), labels.cuda()
        model.zero_grad()
        optimizer.zero_grad()
        # #lr_scheduler.step()
        fake_hyper = model.forward(images)
        #loss = criterion_train_L1(fake_hyper, labels)
        loss, loss_rgb = criterion_train(fake_hyper, labels, images)
        loss_all = loss + opt.trade_off * loss_rgb
        loss_all.backward()
        optimizer.step()
        # Determine approximate time left
        iters_done = epoch * len(train_loader) + i
        iters_left = opt.end_epoch * len(train_loader) - iters_done
        time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
        prev_time = time.time()
        #lr_scheduler.step()
        print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],[Time_left=%s]'
              % (epoch, i+1, len(train_loader), iters_done, time_left))
        # record loss
        losses.update(loss.data)
        losses_rgb.update(loss_rgb.data)
        total_loss.update(loss_all.data)
        print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],[Time_left=%s],[train_losses.avg=%.9f], [rgb_train_losses.avg=%.9f]'
              % (epoch, i+1, len(train_loader), iters_done, time_left, losses.avg, losses_rgb.avg))
    return total_loss.avg, losses.avg, losses_rgb.avg, iters_done
def test(model, test_dataset, criterion):
    model.eval()
    losses = AverageMeter()
    for i, data in enumerate(test_dataset):
        images, labels = data
        # flag labels with values outside the [0, 1] range
        if labels.min() < 0 or labels.max() > 1:
            print("yes there is problem in labels ")
            logger2.info(" IMAGBVAL")
        # flag input images with values outside the [0, 1] range
        if images.min() < 0 or images.max() > 1:
            print("yes there is problem in images ")
            logger2.info("IMAGAVAL")
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            fake_hyper = model.forward(images)
            loss = criterion(fake_hyper, labels)
        losses.update(loss.data)
    return losses.avg
# Learning rate
def poly_lr_scheduler(optimizer, init_lr, iteration, lr_decay_iter=1, max_iter=100, power=0.9):
    """Polynomial decay of learning rate
        :param init_lr is base learning rate
        :param iteration is the current iteration
        :param lr_decay_iter how frequently decay occurs, default is 1
        :param max_iter is number of maximum iterations
        :param power is a polynomial power
    """
    # skip the update between decay steps or once max_iter is exceeded;
    # return the current lr so both paths return the same type
    if iteration % lr_decay_iter or iteration > max_iter:
        return optimizer.param_groups[0]['lr']
    lr = init_lr * (1 - iteration / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
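For reference, the decay formula in the docstring above can be checked on its own. This is only a sketch: poly_lr is a hypothetical helper that is not part of train.py, and the base learning rate of 1e-4 is an assumed value.

```python
def poly_lr(init_lr, iteration, max_iter=100, power=0.9):
    """Polynomial decay: init_lr * (1 - iteration / max_iter) ** power."""
    return init_lr * (1 - iteration / max_iter) ** power


# Learning rate at the start, the middle and the end of the schedule
# (1e-4 is an assumed base learning rate, used only for illustration).
schedule = [poly_lr(1e-4, it) for it in (0, 50, 100)]
print(schedule)
```

The rate starts at init_lr, decreases monotonically, and reaches exactly zero at max_iter.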
The utils.py
from __future__ import division
import torch
import torch.nn as nn
import logging
import numpy as np
import os
import hdf5storage
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
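def initialize_logger(file_dir):
    # use a named logger per file; logging.getLogger() with no name returns the
    # shared root logger, so every log file would receive every message
    logger = logging.getLogger(file_dir)
    fhandler = logging.FileHandler(filename=file_dir, mode='a')
    formatter = logging.Formatter('%(asctime)s - %(message)s', "%Y-%m-%d %H:%M:%S")
    fhandler.setFormatter(formatter)
    logger.addHandler(fhandler)
    logger.setLevel(logging.INFO)
    return logger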
def save_checkpoint(model_path, epoch, iteration, model, optimizer,best=True):
state = {
'epoch': epoch,
'iter': iteration,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
if best == False:
torch.save(state, os.path.join(model_path, 'net_%depoch.pth' % epoch))
else:
torch.save(state, os.path.join(model_path, 'best_net_%depoch.pth' % epoch))
def save_matv73(mat_name, var_name, var):
hdf5storage.savemat(mat_name, {var_name: var}, format='7.3', store_python_metadata=True)
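def record_loss(loss_csv, epoch, train_loss, test_loss):
    """ Record many results."""
    loss_csv.write('{},{},{}\n'.format(epoch, train_loss, test_loss))
    loss_csv.flush()
    # the file is written to every epoch, so it is left open here
    # (the original `loss_csv.close` had no parentheses and did nothing)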
class Loss_train(nn.Module):
def __init__(self):
super(Loss_train, self).__init__()
def forward(self, outputs, label):
error = torch.abs(outputs - label) / label
# error = torch.abs(outputs - label)
rrmse = torch.mean(error.view(-1))
return rrmse
class Loss_valid(nn.Module):
def __init__(self):
super(Loss_valid, self).__init__()
def forward(self, outputs, label):
outputs1 = outputs.clone()
label1 = label.clone()
error = torch.abs(outputs1 - label1) / label1
# error = torch.abs(outputs - label)
rrmse = torch.mean(error.view(-1))
return rrmse
class LossTrainCSS(nn.Module):
def __init__(self):
super(LossTrainCSS, self).__init__()
self.model_hs2rgb = nn.Conv2d(31, 3, 1, bias=False)
filtersPath = './cie_1964_w_gain.npz'
cie_matrix = np.load(filtersPath)['filters']
cie_matrix = torch.from_numpy(np.transpose(cie_matrix, [1, 0])).unsqueeze(-1).unsqueeze(-1).float()
self.model_hs2rgb.weight.data = cie_matrix
def forward(self, outputs, label, rgb_label):
rrmse = self.mrae_loss(outputs, label)
# hs2rgb
with torch.no_grad():
rgb_tensor = self.model_hs2rgb(outputs)
rgb_tensor = rgb_tensor / 255
rgb_tensor = torch.clamp(rgb_tensor, 0, 1) * 255
# rgb_tensor = torch.tensor(rgb_tensor, dtype=torch.uint8)
# rgb_tensor = torch.tensor(rgb_tensor, dtype=torch.uint8)
# update from torch it self is the line below , the original line is below
# the written one
rgb_tensor = rgb_tensor.clone().detach().byte().float()
#rgb_tensor = torch.tensor(rgb_tensor).byte().float()
rgb_tensor = rgb_tensor / 255
rrmse_rgb = self.rgb_mrae_loss(rgb_tensor, rgb_label)
return rrmse, rrmse_rgb
def mrae_loss(self, outputs, label):
error = torch.abs(outputs - label) / label
mrae = torch.mean(error.view(-1))
return mrae
def rgb_mrae_loss(self, outputs, label):
outputs1 = outputs.clone()
label1 = label.clone()
error = torch.abs(outputs1 - label1)
mrae = torch.mean(error.view(-1))
return mrae
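One thing worth ruling out, since the losses above divide by the label: a label pixel that is exactly 0 turns the relative error into inf or nan. A minimal sketch of that behaviour follows. mrae repeats the formula of mrae_loss, while mrae_eps and its eps value are a hypothetical guard that is not used anywhere in the code above.

```python
import torch


def mrae(outputs, label):
    # Same formula as mrae_loss above: mean(|outputs - label| / label)
    return torch.mean(torch.abs(outputs - label) / label)


def mrae_eps(outputs, label, eps=1e-6):
    # Hypothetical variant: eps is an assumed constant, not part of the original code
    return torch.mean(torch.abs(outputs - label) / (label + eps))


label = torch.tensor([0.0, 0.5, 1.0])        # first pixel is exactly zero
off_by_a_bit = torch.tensor([0.1, 0.5, 1.0])
exact_match = torch.tensor([0.0, 0.5, 1.0])

print(mrae(off_by_a_bit, label))      # 0.1 / 0 -> inf
print(mrae(exact_match, label))       # 0 / 0   -> nan
print(mrae_eps(off_by_a_bit, label))  # stays finite
```

If the labels contain zeros, this alone could explain nan losses, independently of whether the labels themselves are being modified.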
If you need a script to check that the dataset generated after cropping has not changed:
import h5py
import numpy as np
import torch
h5f = h5py.File('./Dataset/train_clean.h5', 'r')
for i, key in enumerate(h5f.keys()):
    data = np.array(h5f[key])
    bands = data[0:31, :, :]  # the first 31 channels are the ones checked
    if i % 1000 == 0:
        print(i)
    # values outside the [0, 1] range
    if bands.min() < 0 or bands.max() > 1:
        print("yes there are dude")
        print(key, "with min: {0} and max:{1}".format(bands.min(), bands.max()))
    # NaN or inf entries
    if torch.isnan(torch.tensor(bands)).sum() > 0:
        print("yes damn nan")
    if torch.isinf(torch.tensor(bands)).sum() > 0:
        print("yes damn inf")
Note that if I do not move the labels to CUDA, nothing happens; if I do move them, the labels change even when I remove the loss and the backward step. The GPU I am using is an RTX 2080 Ti and the torch version is '1.10.0'.
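A minimal sketch of how the transfer itself could be checked in isolation, without the model, loss, or data loader. unchanged_after_transfer is a hypothetical helper, and the random tensor only stands in for a real label batch.

```python
import torch


def unchanged_after_transfer(tensor, device):
    """Return True if `tensor` is identical after a copy to `device` and back."""
    reference = tensor.clone()
    moved = tensor.to(device)
    # torch.equal treats NaN as unequal, so a NaN appearing anywhere fails the check
    return torch.equal(reference, moved.cpu()) and torch.equal(reference, tensor)


# Falls back to the CPU when no GPU is visible, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = torch.rand(2, 31, 8, 8)  # stand-in for one batch of 31-channel labels
print(unchanged_after_transfer(labels, device))
```

If this prints False on the GPU for real label batches, the change happens during the copy rather than in the training code.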
The code for the model is in a file called AWAN.py; you can replace it with any other model, they all yield the same behavior:
import torch
from torch import nn
from torch.nn import functional as F
class AWCA(nn.Module):
def __init__(self, channel, reduction=16):
super(AWCA, self).__init__()
self.conv = nn.Conv2d(channel, 1, 1, bias=False)
self.softmax = nn.Softmax(dim=2)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.PReLU(),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, h, w = x.size()
input_x = x
input_x = input_x.view(b, c, h*w).unsqueeze(1)
mask = self.conv(x).view(b, 1, h*w)
mask = self.softmax(mask).unsqueeze(-1)
y = torch.matmul(input_x, mask).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class NONLocalBlock2D(nn.Module):
def __init__(self, in_channels, reduction=8, dimension=2, sub_sample=False, bn_layer=False):
super(NONLocalBlock2D, self).__init__()
assert dimension in [1, 2, 3]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = self.in_channels // reduction
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
bn = nn.BatchNorm3d
elif dimension == 2:
conv_nd = nn.Conv2d
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
bn = nn.BatchNorm2d
else:
conv_nd = nn.Conv1d
max_pool_layer = nn.MaxPool1d(kernel_size=(2))
bn = nn.BatchNorm1d
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0, bias=False)
if bn_layer:
self.W = nn.Sequential(
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0, bias=False),
bn(self.in_channels)
)
nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0)
else:
self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0, bias=False)
nn.init.constant_(self.W.weight, 0)
# nn.init.constant_(self.W.bias, 0)
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0, bias=False)
# self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
# kernel_size=1, stride=1, padding=0)
if sub_sample:
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
def forward(self, x):
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
# phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
# f = torch.matmul(theta_x, phi_x)
f = self.count_cov_second(theta_x)
f_div_C = F.softmax(f, dim=-1)
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
z = W_y + x
return z
def count_cov_second(self, input):
x = input
batchSize, dim, M = x.data.shape
x_mean_band = x.mean(2).view(batchSize, dim, 1).expand(batchSize, dim, M)
y = (x - x_mean_band).bmm(x.transpose(1, 2)) / M
return y
class PSNL(nn.Module):
def __init__(self, channels):
super(PSNL, self).__init__()
# nonlocal module
self.non_local = NONLocalBlock2D(channels)
def forward(self,x):
# divide feature map into 4 part
batch_size, C, H, W = x.shape
H1 = int(H / 2)
W1 = int(W / 2)
nonlocal_feat = torch.zeros_like(x)
feat_sub_lu = x[:, :, :H1, :W1]
feat_sub_ld = x[:, :, H1:, :W1]
feat_sub_ru = x[:, :, :H1, W1:]
feat_sub_rd = x[:, :, H1:, W1:]
nonlocal_lu = self.non_local(feat_sub_lu)
nonlocal_ld = self.non_local(feat_sub_ld)
nonlocal_ru = self.non_local(feat_sub_ru)
nonlocal_rd = self.non_local(feat_sub_rd)
nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu
nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld
nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru
nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd
return nonlocal_feat
class Conv3x3(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1):
super(Conv3x3, self).__init__()
reflect_padding = int(dilation * (kernel_size - 1) / 2)
self.reflection_pad = nn.ReflectionPad2d(reflect_padding)
self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class DRAB(nn.Module):
def __init__(self, in_dim, out_dim, res_dim, k1_size=3, k2_size=1, dilation=1):
super(DRAB, self).__init__()
self.conv1 = Conv3x3(in_dim, in_dim, 3, 1)
self.relu1 = nn.PReLU()
self.conv2 = Conv3x3(in_dim, in_dim, 3, 1)
self.relu2 = nn.PReLU()
# T^{l}_{1}: (conv.)
self.up_conv = Conv3x3(in_dim, res_dim, kernel_size=k1_size, stride=1, dilation=dilation)
self.up_relu = nn.PReLU()
self.se = AWCA(res_dim)
# T^{l}_{2}: (conv.)
self.down_conv = Conv3x3(res_dim, out_dim, kernel_size=k2_size, stride=1)
self.down_relu = nn.PReLU()
def forward(self, x, res):
x_r = x
out = self.relu1(self.conv1(x))
out = self.conv2(out)
out = out + x_r
out = self.relu2(out)
# T^{l}_{1}
out = self.up_conv(out)
out = out + res
out = self.up_relu(out)
res = out
out = self.se(out)
# T^{l}_{2}
out = self.down_conv(out)
out = out + x_r
out = self.down_relu(out)
return out, res
class AWAN(nn.Module):
def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8):
super(AWAN, self).__init__()
# 2D Nets
self.input_conv2D = Conv3x3(inplanes, channels, 3, 1)
self.input_prelu2D = nn.PReLU()
self.head_conv2D = Conv3x3(channels, channels, 3, 1)
self.backbone = nn.ModuleList(
[DRAB(in_dim=channels, out_dim=channels, res_dim=channels, k1_size=5, k2_size=3, dilation=1) for _ in
range(n_DRBs)])
self.tail_conv2D = Conv3x3(channels, channels, 3, 1)
self.output_prelu2D = nn.PReLU()
self.output_conv2D = Conv3x3(channels, planes, 3, 1)
self.tail_nonlocal = PSNL(planes)
def forward(self, x):
out = self.DRN2D(x)
return out
def DRN2D(self, x):
out = self.input_prelu2D(self.input_conv2D(x))
out = self.head_conv2D(out)
residual = out
res = out
for i, block in enumerate(self.backbone):
out, res = block(out, res)
out = self.tail_conv2D(out)
out = torch.add(out, residual)
out = self.output_conv2D(self.output_prelu2D(out))
out = self.tail_nonlocal(out)
return out
if __name__ == "__main__":
# import os
# os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
input_tensor = torch.rand(1, 3, 64, 64)
model = AWAN(3, 31, 200, 10)
# model = nn.DataParallel(model).cuda()
with torch.no_grad():
output_tensor = model(input_tensor)
print(output_tensor.size())
print('Parameters number is ', sum(param.numel() for param in model.parameters()))
print(torch.__version__)
Sorry for the long scripts and thanks a lot.