I am working on a spectral super-resolution problem where the model takes an RGB image with 3 channels as input and is trained against a hyperspectral image with 31 channels (the labels that the output is compared with).
During the training phase, the pixel values of the labels change for no apparent reason.
Any help with this issue would be appreciated.
Can you provide some reproducible code where this is happening, maybe in a Google Colab? Your pixels shouldn’t be changing mid-training; it sounds like you have an in-place operation on the tensors somewhere.
Thanks a lot for your reply. Here is the code of the model:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        batch = x.size(0)
        return x.view(batch, -1)
## for channel descriptors
class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max','adaptive','second']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
if "adaptive" in pool_types:
self.conv = nn.Conv2d(gate_channels, 1, 1, bias=False)
self.softmax = nn.Softmax(dim=2)
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.PReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type=='avg':
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( avg_pool )
elif pool_type=='max':
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( max_pool )
elif pool_type=='lp':
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( lp_pool )
elif pool_type=='lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp( lse_pool )
elif pool_type == 'adaptive':
b, c, h, w = x.size()
input_x = x
input_x = input_x.view(b, c, h*w).unsqueeze(1)
mask = self.conv(x).view(b, 1, h*w)
mask = self.softmax(mask).unsqueeze(-1)
y = torch.matmul(input_x, mask).view(b, c,1,1)
channel_att_raw = self.mlp(y)
elif pool_type == 'second':
b, c, h, w = x.size()
input_x = x
input_x = input_x.view(b, c, h*w)
y = self.count_cov_second(input_x)
cov_mat_sum = torch.mean(y,1)
cov_mat_sum = cov_mat_sum.view(b,c,1,1)
channel_att_raw = self.mlp(cov_mat_sum)
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale
def count_cov_second(self, input_x):
x = input_x
batchSize, dim, M = x.data.shape
x_mean_band = x.mean(2).view(batchSize, dim, 1).expand(batchSize, dim, M)
y = (x - x_mean_band).bmm(x.transpose(1, 2)) / M
return y
def logsumexp_2d(tensor):
    """Log-sum-exp over all trailing dimensions of a ``(N, C, ...)`` tensor.

    The per-channel maximum is subtracted before exponentiating.  The
    result has shape ``(N, C, 1)``.
    """
    n, c = tensor.size(0), tensor.size(1)
    flat = tensor.view(n, c, -1)
    peak = torch.max(flat, dim=2, keepdim=True)[0]
    shifted_sum = (flat - peak).exp().sum(dim=2, keepdim=True)
    return peak + shifted_sum.log()
## for spatial map
class ChannelPool(nn.Module):
    """Stack the channel-wise max and mean of the input as two channels."""

    def forward(self, x):
        """Map ``(N, C, ...)`` to ``(N, 2, ...)``: channel 0 = max, 1 = mean."""
        channel_max = torch.max(x, 1)[0].unsqueeze(1)
        channel_mean = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((channel_max, channel_mean), dim=1)
## for proposed spatial map
class CBranch(nn.Module):
    """Two stacked ``BasicConv`` layers: ``in_channels -> in_channels//16 -> 1``.

    The kernel size and padding of each layer are supplied by the caller.
    """

    def __init__(self,in_channels,kernel_size1,padding1,kernel_size2,padding2):
        super(CBranch, self).__init__()
        reduced = in_channels // 16
        self.conv3x1 = BasicConv(in_channels, reduced, kernel_size=kernel_size1, padding=padding1)
        self.conv1x3 = BasicConv(reduced, 1, kernel_size=kernel_size2, padding=padding2)

    def forward(self, x):
        """Apply the two convolutions in sequence."""
        return self.conv1x3(self.conv3x1(x))
class SpatialContext(nn.Module):
    """Sum of three ``CBranch`` spatial maps with different kernel shapes.

    * ``C1Branch``: (3,1) convolution followed by (1,3)
    * ``C2Branch``: (1,3) convolution followed by (3,1)
    * ``C3Branch``: 3x3 convolution followed by 3x3
    """

    def __init__(self, in_channels):
        super(SpatialContext, self).__init__()
        self.C1Branch = CBranch(in_channels, kernel_size1=(3, 1), padding1=(1, 0), kernel_size2=(1, 3), padding2=(0, 1))
        self.C2Branch = CBranch(in_channels, kernel_size1=(1, 3), padding1=(0, 1), kernel_size2=(3, 1), padding2=(1, 0))
        self.C3Branch = CBranch(in_channels, kernel_size1=3, padding1=1, kernel_size2=3, padding2=1)

    def forward(self, x):
        """Return the element-wise sum of the three branch outputs."""
        # Each branch must be applied once.  Previously C1Branch was called
        # three times, so C2Branch/C3Branch were never used or trained.
        c1_out = self.C1Branch(x)
        c2_out = self.C2Branch(x)
        c3_out = self.C3Branch(x)
        return c1_out + c2_out + c3_out
class ChannelPoolMix(nn.Module):
    """Build a stack of single-channel spatial maps from a feature map.

    One map is produced per entry of ``pool_types`` and the maps are
    concatenated along the channel dimension in that order:

    * ``'avg'``     - mean over channels
    * ``'max'``     - max over channels
    * ``'spatial'`` - 1x1 reduce, ``dilation_conv_num`` dilated 3x3 convs,
      1x1 projection to one channel
    * ``'context'`` - ``SpatialContext`` branch sum

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown name.
    """

    _KNOWN_POOL_TYPES = ('avg', 'max', 'spatial', 'context')

    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4,
                 pool_types=('avg', 'max', 'spatial', 'context')):
        super(ChannelPoolMix, self).__init__()
        # The default used to contain 'Context', which matched no branch in
        # forward(): the previous map was silently concatenated again.
        # Unknown names are now rejected up front instead.
        unknown = [p for p in pool_types if p not in self._KNOWN_POOL_TYPES]
        if unknown or not pool_types:
            raise ValueError(
                "pool_types must be a non-empty subset of %r, got %r"
                % (self._KNOWN_POOL_TYPES, pool_types))
        self.pool_types = pool_types
        if 'spatial' in self.pool_types:
            reduced = gate_channel // reduction_ratio
            self.gate_s = nn.Sequential()
            self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, reduced, kernel_size=1))
            self.gate_s.add_module('gate_s_relu_reduce0', nn.PReLU())
            for i in range(dilation_conv_num):
                # padding == dilation keeps the spatial size for a 3x3 kernel
                self.gate_s.add_module('gate_s_conv_di_%d' % i,
                                       nn.Conv2d(reduced, reduced, kernel_size=3,
                                                 padding=dilation_val, dilation=dilation_val))
                self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.PReLU())
            self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(reduced, 1, kernel_size=1))
        if 'context' in self.pool_types:
            self.spatial_context = SpatialContext(gate_channel)

    def forward(self, x):
        """Return the per-type maps concatenated on dim 1."""
        spatial_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                out = torch.mean(x, 1).unsqueeze(1)
            elif pool_type == 'max':
                out = torch.max(x, 1)[0].unsqueeze(1)
            elif pool_type == 'spatial':
                out = self.gate_s(x)
            elif pool_type == 'context':
                out = self.spatial_context(x)
            else:
                # Only reachable if self.pool_types was changed after __init__.
                raise ValueError("unknown pool type %r" % (pool_type,))
            # Identity test: `== None` on a tensor is an element-wise op.
            if spatial_att_sum is None:
                spatial_att_sum = out
            else:
                spatial_att_sum = torch.cat((spatial_att_sum, out), dim=1)
        return spatial_att_sum
class SpatialGate(nn.Module):
    """Spatial attention: max/mean channel pooling, 7x7 conv, sigmoid gate."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        half = (kernel_size - 1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=half, relu=False, bn=False)

    def forward(self, x):
        """Return ``x`` multiplied by the single-channel sigmoid map."""
        attention = torch.sigmoid(self.spatial(self.compress(x)))
        # the one-channel map broadcasts over the channel dimension
        return x * attention
## proposed spatial gate
class SpatialGateMix(nn.Module):
    """Spatial gate driven by ``ChannelPoolMix``.

    The pooled maps (one channel per entry of ``pool_types``) are fused by a
    7x7 ``BasicConv`` into a single-channel map whose sigmoid scales the input.
    """

    def __init__(self,gate_channel,reduction_ratio=16, dilation_conv_num=2, dilation_val=4,pool_types = ['avg','max']):
        super(SpatialGateMix, self).__init__()
        kernel_size = 7
        half = (kernel_size - 1) // 2
        self.compress = ChannelPoolMix(
            gate_channel=gate_channel,
            reduction_ratio=reduction_ratio,
            dilation_conv_num=dilation_conv_num,
            dilation_val=dilation_val,
            pool_types=pool_types,
        )
        # one input channel per pooled map
        self.spatial = BasicConv(len(pool_types), 1, kernel_size, stride=1, padding=half, relu=True, bn=False)

    def forward(self, x):
        """Return ``x`` multiplied by the sigmoid attention map."""
        pooled = self.compress(x)
        attention = torch.sigmoid(self.spatial(pooled))
        # the one-channel map broadcasts over the channel dimension
        return x * attention
class Modified_CBAM(nn.Module):
    """3x3 conv, channel gate and optional spatial gate, plus a skip connection.

    When ``no_spatial`` is true the spatial gate is neither built nor applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types_channel=['avg', 'max','adaptive','second'],pool_types_spatial = ['avg','max','spatial','context'], no_spatial=False):
        super(Modified_CBAM, self).__init__()
        self.input_conv = BasicConv(gate_channels, gate_channels, kernel_size=3, padding=1, dilation=1)
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types_channel)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGateMix(gate_channels, pool_types=pool_types_spatial)

    def forward(self, x):
        """Return ``gates(conv(x)) + x``."""
        gated = self.ChannelGate(self.input_conv(x))
        if not self.no_spatial:
            gated = self.SpatialGate(gated)
        return gated + x
class Conv3x3(nn.Module):
    """Bias-free convolution preceded by reflection padding.

    The padding is ``int(dilation * (kernel_size - 1) / 2)`` on every side.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, dilation=1):
        super(Conv3x3, self).__init__()
        pad = int(dilation * (kernel_size - 1) / 2)
        self.reflection_pad = nn.ReflectionPad2d(pad)
        self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation, bias=False)

    def forward(self, x):
        """Reflect-pad ``x`` and convolve it."""
        return self.conv2d(self.reflection_pad(x))
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and/or PReLU.

    ``bn`` and ``relu`` select the optional stages; a disabled stage is
    stored as ``None``.  ``out_channels`` records ``out_planes``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, relu=True, bn=False, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.PReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Run conv, then whichever of bn / PReLU are enabled."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
class ResduialBlock(nn.Module):
    """A ``BasicConv`` with an identity skip connection: ``conv(x) + x``.

    All constructor arguments are forwarded to ``BasicConv``.  For the
    addition in ``forward`` to work the convolution output must have the
    same shape as its input.
    """

    def __init__(self,in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, relu=True, bn=False, bias=False):
        super(ResduialBlock, self).__init__()
        # `groups` was previously hard-coded to 1 here, so the argument had
        # no effect; it is now forwarded like the other options.
        self.conv = BasicConv(in_planes=in_planes, out_planes=out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups,
                              relu=relu, bn=bn, bias=bias)

    def forward(self, x):
        """Return ``self.conv(x) + x``."""
        out = self.conv(x)
        out = out + x
        return out
class HyraricalResduial(nn.Module):
    """Three chained ``BasicConv`` stages with dilation 1, 2 and 3.

    Each later stage receives the block input plus the previous stage's
    output; the final result is the input plus all three stage outputs.
    """

    def __init__(self, in_planes, out_planes):
        super(HyraricalResduial, self).__init__()
        self.top_res = BasicConv(in_planes, out_planes, dilation=1, padding=1)
        self.middle_res = BasicConv(in_planes, out_planes, dilation=2, padding=2)
        self.bottom_res = BasicConv(in_planes, out_planes, dilation=3, padding=3)

    def forward(self, x):
        """Return ``x + top + middle + bottom``."""
        top = self.top_res(x)
        middle = self.middle_res(x + top)
        bottom = self.bottom_res(x + middle)
        return x + top + middle + bottom
class HRBLOCK(nn.Module):
    """Residual block: Conv3x3 + PReLU, hierarchical residual convs, a
    dilated conv + PReLU, channel attention, then a skip from the first
    activation.
    """

    def __init__(self, in_dim, out_dim):
        super(HRBLOCK, self).__init__()
        self.conv1 = Conv3x3(in_dim, out_dim, 3, 1)
        self.relu1 = nn.PReLU()
        self.conv2 = HyraricalResduial(in_planes=out_dim, out_planes=out_dim)
        # T^{l}_{1}: (conv.)
        self.up_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=2, dilation=2, bias=False)
        self.up_relu = nn.PReLU()
        self.se = Modified_CBAM(gate_channels=out_dim, reduction_ratio=16, no_spatial=True)

    def forward(self, x):
        """Return the attended features plus the post-``relu1`` shortcut."""
        shortcut = self.relu1(self.conv1(x))
        feat = self.conv2(shortcut)
        # T^{l}_{1}
        feat = self.up_relu(self.up_conv(feat))
        feat = self.se(feat)
        return feat + shortcut
class BackBone(nn.Module):
    """Stack of ``HRBLOCK`` modules with dense feature concatenation.

    Args:
        inplanes: channels of the input image.
        planes: channels of the output.
        channels: feature width inside the network.
        n_DRBs: number of ``HRBLOCK`` modules in the trunk.
    """

    def __init__(self, inplanes=3, planes=31, channels=200, n_DRBs=8):
        super(BackBone, self).__init__()
        # 2D Nets
        self.input_conv2D = Conv3x3(inplanes, channels, 3, 1)
        trunk = [HRBLOCK(in_dim=channels, out_dim=channels) for _ in range(n_DRBs)]
        self.backbone = nn.ModuleList(trunk)
        # fuses the concatenated outputs of all trunk blocks
        self.cblock = HRBLOCK(in_dim=n_DRBs * channels, out_dim=channels)
        self.oblock = HRBLOCK(in_dim=channels, out_dim=channels)
        self.output_prelu2D = nn.PReLU()
        self.output_conv2D = Conv3x3(channels, planes, 3, 1)

    def forward(self, x):
        """Delegate to :meth:`DRN2D`."""
        return self.DRN2D(x)

    def DRN2D(self, x):
        """Run the trunk, fuse every block output, add the stem skip, project."""
        stem = self.input_conv2D(x)
        out = stem
        collected = []
        for block in self.backbone:
            out = block(out)
            collected.append(out)
        # same channel order as concatenating one block at a time
        fused = torch.cat(collected, dim=1)
        out = self.oblock(self.cblock(fused))
        out = torch.add(out, stem)
        return self.output_conv2D(self.output_prelu2D(out))
And here is the code for the main file:
def main():
    """Build the data loaders, model and optimizer, then train and validate.

    Uses names defined outside this block (``opt``, ``cudnn``, ``visdom``,
    ``DataLoader``, ``optim``, the dataset and loss classes,
    ``initialize_logger``, ``save_checkpoint``, ``record_loss``, ``train``,
    ``test``).  The loss CSV is opened in a ``with`` block so the file is
    closed when training ends or raises; it was previously never closed.
    """
    cudnn.benchmark = True
    # load dataset
    print("\nloading dataset ...")
    train_dataset = HyperDatasetTrain1(mode='train')
    test_dataset = HyperDatasetValid(mode='valid')
    print("Train set sample:%d," % (len(train_dataset),))
    print("Validation set samples: ", len(test_dataset))
    # Data Loader (Input Pipeline)
    train_loader1 = DataLoader(dataset=train_dataset, batch_size=opt.batchSize, shuffle=True,
                               num_workers=2, pin_memory=True, drop_last=True)
    # train() expects a list of loaders
    train_loader = [train_loader1]
    val_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
    viz = visdom.Visdom(env="proposed-model1")
    if not viz.check_connection():
        print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
    # model
    print("\nbuilding models_baseline ...")
    model = BackBone(3, 31, 128, 6)
    print('Parameters number is ', sum(param.numel() for param in model.parameters()))
    criterion_train = LossTrainCSS()
    criterion_train_L1 = torch.nn.L1Loss().cuda()
    criterion_valid = Loss_valid()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # batchsize integer times
    if torch.cuda.is_available():
        model.cuda()
        criterion_train.cuda()
        criterion_valid.cuda()
    # Parameters, Loss and Optimizer
    start_epoch = 0
    iteration = 0
    record_val_loss = 1000
    optimizer = optim.Adam(model.parameters(), lr=opt.init_lr, betas=(opt.b1, opt.b2),
                           weight_decay=opt.weight_decay)
    # visualzation
    if not os.path.exists(opt.outf):
        os.makedirs(opt.outf)
    log_dir = os.path.join(opt.outf, 'train.log')
    with open(os.path.join(opt.outf, 'loss.csv'), 'a+') as loss_csv:
        logger = initialize_logger(log_dir)
        # Resume
        resume_file = opt.outf + '/best_net_21epoch.pth'
        if resume_file and os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        # start epoch
        for epoch in range(start_epoch + 1, opt.end_epoch):
            start_time = time.time()
            total_loss, train_loss, rgb_train_loss, lr, iteration = train(
                train_loader, model, criterion_train, criterion_train_L1, optimizer, epoch, iteration, opt)
            val_loss = test(model, val_loader, criterion_valid)
            # Save model either the best model so far or every 10 epochs
            if torch.abs(val_loss - record_val_loss) < 0.0001 or val_loss < record_val_loss:
                save_checkpoint(opt.outf, epoch, iteration, model, optimizer, best=True)
                if val_loss < record_val_loss:
                    record_val_loss = val_loss
            if epoch % 10 == 0:
                # NOTE(review): the iteration stored here is the constant 1000,
                # not `iteration`; resuming from such a checkpoint restarts the
                # LR schedule at 1000 - confirm this is intended.
                save_checkpoint(opt.outf, epoch, 1000, model, optimizer, best=False)
            # print loss
            end_time = time.time()
            epoch_time = end_time - start_time
            print("Epoch [%02d], Iter[%06d], Time:%.9f, learning rate : %.9f, Train Loss: %.9f Test Loss: %.9f "
                  % (epoch, iteration, epoch_time, lr, train_loss, val_loss))
            # for learning rate
            viz.line([optimizer.param_groups[0]['lr'] * (10 ** 4)], [epoch], win='Learning rate schedule',
                     update='append',
                     opts=dict(title='Learning rate schedule',
                               legend=['lr*(10^4)']))
            # for HSI train loss
            viz.line([train_loss.detach().cpu()], [epoch], win='HSI Train_loss',
                     update='append', opts=dict(title=' HSI Train Learning Curve.',
                                                legend=['HSI Train Loss']))
            # for validation loss
            viz.line([val_loss.detach().cpu()], [epoch], win='Val Train_loss',
                     update='append', opts=dict(title='val loss Learning Curve.',
                                                legend=['val loss']))
            # for rgb train loss
            viz.line([rgb_train_loss.detach().cpu()], [epoch], win=' RGB Train_loss',
                     update='append', opts=dict(title='rgb train loss Learning Curve.',
                                                legend=['rgb train loss']))
            # for total train loss
            viz.line([total_loss.detach().cpu()], [epoch], win='Total Train_loss',
                     update='append', opts=dict(title='total train loss Learning Curve.',
                                                legend=['total train loss']))
            # for train_loss and validation_loss
            viz.line([[train_loss.detach().cpu(), val_loss.detach().cpu()]], [epoch],
                     win='Train_loss and val_loss',
                     update='append', opts=dict(title='Learning Curve.',
                                                legend=['Train Loss', 'Validation Loss']))
            # save loss
            record_loss(loss_csv, epoch, train_loss, val_loss)
            logger.info("Epoch [%02d], Train Loss: %.9f Test Loss: %.9f "
                        % (epoch, train_loss, val_loss))
And the code for the training and evaluation loops is as follows:
def train(train_loader, model, criterion_train,criterion_train_L1, optimizer, epoch, iteration, opt):
    """Run one training epoch over every loader in ``train_loader``.

    Args:
        train_loader: iterable of data loaders, each yielding
            ``(images, labels)`` batches.
        model: network being trained.
        criterion_train: callable returning ``(loss, loss_rgb)`` for
            ``(prediction, labels, images)``.
        criterion_train_L1: unused here; kept for interface compatibility.
        optimizer: optimizer to step.
        epoch: current epoch number (logging / time estimate only).
        iteration: global iteration counter before this epoch.
        opt: options providing ``init_lr``, ``max_iter``, ``decay_power``,
            ``trade_off`` and ``end_epoch``.

    Returns:
        ``(total_loss_avg, loss_avg, rgb_loss_avg, lr, iteration)``.

    ``AverageMeter``, ``poly_lr_scheduler`` and ``logger2`` are expected to
    be defined at module level.
    """
    total_loss = AverageMeter()
    losses = AverageMeter()
    losses_rgb = AverageMeter()
    prev_time = time.time()
    # Defined up front so the return value exists even if no batch is seen.
    lr = optimizer.param_groups[0]['lr']
    model.train()
    for train_data_loader in train_loader:
        for i, data in enumerate(train_data_loader):
            images, labels = data
            # to only test the labels that having values other than 0 and 1
            if images.min() < 0 or labels.max() > 1:
                print("yes there is problem in labels ")
                logger2.info("Epoch [%02d],batch no: %d/%d"
                             % (epoch, i + 1, len(train_data_loader)))
            images, labels = images.cuda(), labels.cuda()
            model.zero_grad()
            optimizer.zero_grad()
            lr = poly_lr_scheduler(optimizer, opt.init_lr, iteration, max_iter=opt.max_iter, power=opt.decay_power)
            iteration = iteration + 1
            # Call the module, not .forward(), so registered hooks still run.
            fake_hyper = model(images)
            loss, loss_rgb = criterion_train(fake_hyper, labels, images)
            loss_all = loss + opt.trade_off * loss_rgb
            loss_all.backward()
            optimizer.step()
            # Determine approximate time left
            iters_done = epoch * len(train_loader) * len(train_data_loader) + i
            iters_left = opt.end_epoch * len(train_loader) * len(train_data_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()
            # record loss (detached so the meters do not keep the graph alive)
            losses.update(loss.detach())
            losses_rgb.update(loss_rgb.detach())
            total_loss.update(loss_all.detach())
            print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],lr=%.9f,[Time_left=%s],[train_losses.avg=%.9f], [rgb_train_losses.avg=%.9f]'
                  % (epoch, i + 1, len(train_data_loader), iteration, lr, time_left, losses.avg, losses_rgb.avg))
    return total_loss.avg, losses.avg, losses_rgb.avg, lr, iteration
def test(model, test_dataset, criterion):
    """Evaluate ``model`` on every batch of ``test_dataset``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_dataset: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(prediction, labels)`` returning the
            loss for one batch.

    Returns:
        The running average held by the ``AverageMeter`` (a module-level
        helper not shown in this block).
    """
    model.eval()
    losses = AverageMeter()
    for data in test_dataset:
        images, labels = data
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            # Call the module, not .forward(), so registered hooks still run.
            fake_hyper = model(images)
            loss = criterion(fake_hyper, labels)
        losses.update(loss.detach())
    return losses.avg
# Learning rate
def poly_lr_scheduler(optimizer, init_lr, iteraion, lr_decay_iter=1, max_iter=100, power=0.9):
    """Apply polynomial learning-rate decay and return the rate in effect.

    :param optimizer: object exposing ``param_groups``, a list of dicts
        with an ``'lr'`` entry
    :param init_lr: base learning rate
    :param iteraion: current iteration
    :param lr_decay_iter: how frequently decay occurs, default is 1
    :param max_iter: maximum number of iterations
    :param power: polynomial power
    :return: the learning rate as a number.  When no update is due (the
        iteration is not a multiple of ``lr_decay_iter`` or is past
        ``max_iter``) the optimizer is left untouched and the rate of its
        first parameter group is returned.
    """
    if iteraion % lr_decay_iter or iteraion > max_iter:
        # This path used to return the optimizer object itself, which broke
        # callers that format the result with '%.9f'.
        return optimizer.param_groups[0]['lr']
    lr = init_lr * (1 - iteraion / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
If you need the code for the dataset class, let me know. Thanks a lot again for your help.
Hi Ahmed -
This is quite a bit of code
One quick check would be to add a torch.copy()
in the training and testing loops:
for i,data in enumerate(train_data_loader):
images, labels = data
images = torch.copy(images)
labels = torch.copy(labels)
And check if this fixes the problem. This doesn’t pinpoint the bug, but it at least confirms that we’re dealing with some inadvertent changing-in-place somewhere inside the forward call.
(by the way, how do you know the pixel values are changing?)
torch.copy
doesn’t exist in PyTorch version 1.10.0, so I used img_A = torch.detach(img_A1).clone()
and unfortunately that doesn’t work either. The values of the labels keep changing until they contain NaN values, which makes the gradients explode. I am using this control statement inside the training loop to check the labels at every iteration:
# to only test the labels that having values other than 0 and 1
if images.min()<0 or labels.max()>1:
print("yes there is problem in labels ")
logger2.info("Epoch [%02d],batch no: %d/%d"
% (epoch,i+1,len(train_data_loader)))
and the complete training loop is as follows:
def train(train_loader, model, criterion_train,criterion_train_L1, optimizer, epoch, iteration, opt):
    """Run one training epoch over every loader in ``train_loader``.

    Args:
        train_loader: iterable of data loaders, each yielding
            ``(images, labels)`` batches.
        model: network being trained.
        criterion_train: callable returning ``(loss, loss_rgb)`` for
            ``(prediction, labels, images)``.
        criterion_train_L1: unused here; kept for interface compatibility.
        optimizer: optimizer to step.
        epoch: current epoch number (logging / time estimate only).
        iteration: global iteration counter before this epoch.
        opt: options providing ``init_lr``, ``max_iter``, ``decay_power``,
            ``trade_off`` and ``end_epoch``.

    Returns:
        ``(total_loss_avg, loss_avg, rgb_loss_avg, lr, iteration)``.

    ``AverageMeter``, ``poly_lr_scheduler`` and ``logger2`` are expected to
    be defined at module level.
    """
    total_loss = AverageMeter()
    losses = AverageMeter()
    losses_rgb = AverageMeter()
    prev_time = time.time()
    # Defined up front so the return value exists even if no batch is seen.
    lr = optimizer.param_groups[0]['lr']
    model.train()
    for train_data_loader in train_loader:
        for i, data in enumerate(train_data_loader):
            images, labels = data
            # to only test the labels that having values other than 0 and 1
            if images.min() < 0 or labels.max() > 1:
                print("yes there is problem in labels ")
                logger2.info("Epoch [%02d],batch no: %d/%d"
                             % (epoch, i + 1, len(train_data_loader)))
            images, labels = images.cuda(), labels.cuda()
            model.zero_grad()
            optimizer.zero_grad()
            lr = poly_lr_scheduler(optimizer, opt.init_lr, iteration, max_iter=opt.max_iter, power=opt.decay_power)
            iteration = iteration + 1
            # Call the module, not .forward(), so registered hooks still run.
            fake_hyper = model(images)
            loss, loss_rgb = criterion_train(fake_hyper, labels, images)
            loss_all = loss + opt.trade_off * loss_rgb
            loss_all.backward()
            optimizer.step()
            # Determine approximate time left
            iters_done = epoch * len(train_loader) * len(train_data_loader) + i
            iters_left = opt.end_epoch * len(train_loader) * len(train_data_loader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left * (time.time() - prev_time))
            prev_time = time.time()
            # record loss (detached so the meters do not keep the graph alive)
            losses.update(loss.detach())
            losses_rgb.update(loss_rgb.detach())
            total_loss.update(loss_all.detach())
            print('[Epoch:%02d],[Batch NO:%d/%d],[iter:%d],lr=%.9f,[Time_left=%s],[train_losses.avg=%.9f], [rgb_train_losses.avg=%.9f]'
                  % (epoch, i + 1, len(train_data_loader), iteration, lr, time_left, losses.avg, losses_rgb.avg))
    return total_loss.avg, losses.avg, losses_rgb.avg, lr, iteration
Thanks a lot for your concern
Hi Ahmed -
How do you know that images & labels are changing, as opposed to a few of them being corrupted in the first place? If your dataloader is not randomizing, and you don’t get any error the first epoch, then that’s pretty good evidence. However, if your dataloader is randomizing and potentially not going through the entire data per epoch, the possibility remains that some datapoints are corrupted in the first place, and whenever you (randomly) sample them, you raise an error.
Are you using any data augmentation anywhere that could possibly be transforming the raw data (with some bug)?
Is there a possibility that some other process is altering either your images or labels while you’re performing training? Not sure what your setup is but worth checking.
Just asking about all these more basic things first, since it would be good to entirely rule them out before going hunting for a fancier bug.
If you’ve checked this stuff and it seems fine, then I would debug this by adding that control statement many times over inside your forward calls (with an added index so you know which of the control statements is printing the error). That will allow you to pinpoint the problematic step, and hopefully it turns out there’s a particular one causing it. This seems to me like the fastest way to figure out the issue here.
Thanks,
Andrei
Thanks for your reply. I check the dataset before training to make sure that it is correctly normalized and doesn’t contain any NaN values. I am not using any data augmentation in my code, and the weird thing is that only the labels change, until they are corrupted with NaN values and the gradients explode. The network I am using isn’t a standard one; I am trying to design a new network architecture. The training loops over the entire dataset in every epoch, but the NaN values always appear between epochs three and five. I only touch the labels in the loss function. I am adding a small piece of code at a time and training for five epochs to catch the problem. I think there are in-place operations, as you suggested in the first place, and I am checking for them, but so far I have had no luck getting the training to be stable.
Thanks a lot for your suggestions and debugging
Hi Ahmed -
That’s useful. It looks like the labels get passed to criterion_train
, which is an instance of LossTrainCSS()
– do you have the code for that?
That could be where we will find the bug.
Andrei
Thanks for the reply. Here is the code for the loss of both valid and train loops
class Loss_valid(nn.Module):
    """Validation loss: mean of ``|outputs - label| / label``."""

    def __init__(self):
        super(Loss_valid, self).__init__()

    def forward(self, outputs, label):
        """Return the mean relative absolute error over all elements.

        The division is element-wise by ``label``.
        """
        relative = torch.abs(outputs - label) / label
        return torch.mean(relative.view(-1))
class LossTrainCSS(nn.Module):
    """Training loss with a hyperspectral term and an RGB-reprojection term.

    A bias-free 1x1 convolution (31 -> 3 channels), whose weights are read
    from ``./cie_1964_w_gain.npz``, maps the prediction to RGB.
    """

    def __init__(self):
        super(LossTrainCSS, self).__init__()
        self.model_hs2rgb = nn.Conv2d(31, 3, 1, bias=False)
        filtersPath = './cie_1964_w_gain.npz'
        filters = np.load(filtersPath)['filters']
        # transpose the stored matrix, then append two singleton dims so it
        # has the 4-D layout of a 1x1 conv weight
        weights = torch.from_numpy(np.transpose(filters, [1, 0]))
        weights = weights.unsqueeze(-1).unsqueeze(-1).float()
        self.model_hs2rgb.weight.data = weights

    def forward(self, outputs, label, rgb_label):
        """Return ``(mrae(outputs, label), rgb_error)``.

        NOTE: the RGB projection runs under ``torch.no_grad()`` and is
        detached and quantised through ``byte()``, so the second value
        carries no gradient back to ``outputs``.
        """
        rrmse = self.mrae_loss(outputs, label)
        # hs2rgb
        with torch.no_grad():
            rgb = self.model_hs2rgb(outputs)
            rgb = rgb / 255
            rgb = torch.clamp(rgb, 0, 1) * 255
            # round-trip through uint8, then rescale
            rgb = rgb.clone().detach().byte().float()
            rgb = rgb / 255
        rrmse_rgb = self.rgb_mrae_loss(rgb, rgb_label)
        return rrmse, rrmse_rgb

    def mrae_loss(self, outputs, label):
        """Mean of ``|outputs - label| / label`` over all elements."""
        relative = torch.abs(outputs - label) / label
        return torch.mean(relative.view(-1))

    def rgb_mrae_loss(self, outputs, label):
        """Mean of ``|outputs - label|`` over all elements."""
        absolute = torch.abs(outputs - label)
        return torch.mean(absolute.view(-1))
Thanks again
Can label
take the value 0?
Very interesting question. The label is a hyperspectral image, in other words a tensor with dimensions NxCxWxH; the division is performed element-wise, one whole image over another, and the result is also an image before the mean is taken. I also checked the min and max values of the labels, and the minimum value across all labels is roughly 0.0001. Note also that the loss values start at 1.0 and go down to 0.03, so all the numbers are small and should cancel each other out.
I see, you might consider flooring that denominator just to make sure you don’t explode and overflow, by replacing it with something like torch.max(torch.Tensor([0.01]), label) (or, equivalently, torch.clamp(label, min=0.01)).
This would be pretty weird, but is it possible that your .backward()
call is updating the labels, given that your loss is a Module? Earlier when I suggested torch.copy
and you said you did .detach().clone()
you said you did it for img_A1
, did you also do it for the labels?
Other than that, I’m out of ideas, except anomaly detection and lots of debug print statements to pinpoint when exactly the problem occurs.
.detach().clone()
I did this function for both images and labels. I will try to simplify things up to the very basic stuff at first until know the possible reason for the problem. I am really very grateful for helping me out this long. Thanks a lot again