The code is below:
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that shifts and scales the three RGB channels.

    The layer computes, per channel, ``x / std + sign * rgb_range * mean / std``.
    With ``sign=-1`` this is ``(x - rgb_range * mean) / std`` (normalisation).
    With ``sign=+1`` it is ``(x + rgb_range * mean) / std``, which undoes the
    ``sign=-1`` layer only when ``std`` is 1 for every channel.

    Args:
        rgb_range: Scalar multiplied with ``rgb_mean`` to form the bias, i.e.
            the numeric range the mean is expressed relative to.
        rgb_mean: Sequence of three per-channel means.
        rgb_std: Sequence of three per-channel standard deviations.
        sign: ``-1`` to subtract the mean, ``+1`` to add it back.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.tensor(rgb_std, dtype=self.weight.dtype)
        mean = torch.tensor(rgb_mean, dtype=self.bias.dtype)
        with torch.no_grad():
            # Identity channel mixing, scaled by 1/std per output channel.
            self.weight.copy_(torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            self.bias.copy_(sign * rgb_range * mean / std)
        # Freeze the parameters themselves. Assigning ``requires_grad`` on the
        # module only creates an ordinary attribute and leaves the weight and
        # bias trainable, so the shift would be altered by the optimizer.
        for param in self.parameters():
            param.requires_grad = False
        # Retained as a plain flag for any code that reads it from the module.
        self.requires_grad = False
1. I have computed the mean and std of the DIV2K HR training set:
mean=[0.4485, 0.4375, 0.4045] std=[0.2436, 0.2330, 0.2424]
The mean is similar to the author's, while the std is totally different,
so how should the std be set?
2. My own training set has mean=[0.5164, 0.5179, 0.4987], std=[0.2256, 0.2194, 0.2282].
I use the mean of my own training set and set the std to [1.0, 1.0, 1.0].
The model's code is below:
# Fragment of the model constructor. The enclosing method header and the names
# rgb_range, conv, n_colors, n_feats, kernel_size, reduction, act, res_scale,
# n_resblocks, n_resgroups and scale are defined outside this excerpt.
# Per-channel RGB mean: the commented-out tuple is the DIV2K value, the active
# tuple is the mean measured on the custom training set.
#rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_mean=(0.5164,0.5179,0.4987)
# Unit std: MeanShift divides both its identity weight and its bias by std, so
# ones leave a pure mean subtraction / addition.
rgb_std = (1.0, 1.0, 1.0)
# NOTE(review): MeanShift builds its bias as sign * rgb_range * mean, so
# rgb_range presumably has to match the numeric range of the input pixels
# (e.g. 1.0 for [0, 1] tensors, 255 for [0, 255]) -- confirm against the data
# pipeline, since a mismatch would shift inputs by the wrong amount.
self.sub_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std)
# Head: a single conv called with (n_colors, n_feats, kernel_size).
modules_head = [conv(n_colors, n_feats, kernel_size)]
# Body: n_resgroups ResidualGroup modules followed by one n_feats -> n_feats conv.
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size))
# Tail: Upsampler (called with scale and act=False) followed by a conv from
# n_feats back to n_colors.
modules_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, n_colors, kernel_size)]
# Same mean/std as sub_mean but with sign=+1, so the mean is added back.
self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1)
# Wrap each module list so it can be called as a single layer.
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
def forward(self, x):
    """Run the network: mean shift, head, body with skip, tail, mean restore.

    Args:
        x: Input passed to ``self.sub_mean``.

    Returns:
        The output of ``self.add_mean`` applied to the tail result.
    """
    shallow = self.head(self.sub_mean(x))
    deep = self.body(shallow)
    # Long skip connection; added in place, as the body output is not reused.
    deep += shallow
    return self.add_mean(self.tail(deep))
When I add the mean_shift layers, the network does not work well;
when I remove the mean_shift layers, the network works well.
So, what is the purpose of mean_shift, and why does the network give such bad results when I add it?
My code for calculating the mean and std is below:
# Collect the image files whose names end in 'H.png'.
# glob.glob already returns each match with the directory part of the pattern
# included, so the paths are used as returned; joining `dir` onto them again
# would duplicate the prefix whenever `dir` is a relative path.
# NOTE(review): `dir` is expected to be a path string defined earlier in the
# script (it shadows the builtin of the same name) -- confirm.
img_list = sorted(glob.glob(dir + '*H.png'))
print(len(img_list))
class MyDataset(Dataset):
    """Dataset that loads each listed image file and returns it as a tensor.

    Args:
        img_list: Sequence of image file paths.
    """

    def __init__(self, img_list):
        self.data = img_list

    def __getitem__(self, index):
        """Load the image at ``index`` and convert it with ``ToTensor``."""
        path = self.data[index]
        # The context manager releases the file handle deterministically.
        # Converting to RGB keeps the channel count at three even for RGBA,
        # palette or grayscale files, so per-channel statistics line up.
        with Image.open(path) as img:
            return ToTensor()(img.convert('RGB'))

    def __len__(self):
        """Return the number of image paths."""
        return len(self.data)
# Compute the per-channel mean and standard deviation over every pixel of the
# dataset. Sums and sums of squares are accumulated across all images so that
# `std` is the standard deviation of the whole dataset; averaging per-image
# stds ignores the variation between images and weights images of different
# sizes equally.
dataset = MyDataset(img_list)
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False,
)

channel_sum = 0.
channel_sq_sum = 0.
nb_pixels = 0
nb_samples = 0
i = 0
for data in tqdm(loader):
    batch_samples = data.size(0)
    # Flatten the spatial dimensions: (batch, channels, pixels). Accumulate in
    # float64 to keep the running sums accurate over many pixels.
    data = data.view(batch_samples, data.size(1), -1).double()
    channel_sum += data.sum(2).sum(0)
    channel_sq_sum += (data ** 2).sum(2).sum(0)
    nb_pixels += batch_samples * data.size(2)
    nb_samples += batch_samples
    i += 1

if nb_pixels == 0:
    raise ValueError('no images found: cannot compute mean and std')

mean = channel_sum / nb_pixels
# Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
std = (channel_sq_sum / nb_pixels - mean ** 2).clamp(min=0).sqrt()
mean = mean.float()
std = std.float()
print(i, mean, std)