I have tried to implement a “complicated” (for me…) loss function (a porting from MATLAB repo).
Strangely it compiles, but when I try to call the .backward()
function it returns to me a NoneType.
I’m analyzing the code but, even if there are many things that I’m not understanding, I don’t find the real cause of the problem.
Could someone help me in this desperate effort?
The code to reproduce the error is:
import torch
import math
import numpy as np
from torch import nn
def normalize_block(im):
m = torch.mean(im)
s = torch.std(im)
if s == 0:
s = 1e-7
y = ((im-m)/s)+1
return y, m, s
def onion_mult(onion1, onion2):
bs, N = onion1.size()
if N>1:
L = int(N/2)
a = onion1[:, :L]
b = onion1[:, L:]
b = torch.cat((torch.unsqueeze(b[:,0], 1), -b[:,1:]), dim=1)
c = onion2[:, :L]
d = onion2[:, L:]
d = torch.cat((torch.unsqueeze(d[:, 0], 1), -d[:, 1:]), dim=1)
if N == 2:
ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)), dim=1)
else:
ris1 = onion_mult(a, c)
ris2 = onion_mult(d, torch.cat((torch.unsqueeze(b[:, 0], 1), -b[:, 1:]), dim=1))
ris3 = onion_mult(torch.cat((torch.unsqueeze(a[:, 0], 1), -a[:, 1:]), dim=1), d)
ris4 = onion_mult(c, b)
aux1 = ris1 - ris2
aux2 = ris3 + ris4
ris = torch.cat((aux1, aux2), dim=1)
else:
ris = onion1 * onion2
return ris
def onion_mult2D(onion1, onion2):
bs, dim3, _, _ = onion1.size()
if (dim3 > 1):
L = int(dim3/2)
a = onion1[:, 0:L, :, :]
b = onion1[:, L:, :, :]
b = torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:,1:, :, :],), dim=1)
c = onion2[:, 0:L, :, :]
d = onion2[:, L:, :, :]
d = torch.cat((torch.unsqueeze(d[:, 0, :, :], 1), -d[:, 1:, :, :],), dim=1)
if dim3 == 2:
ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)),dim=1)
else:
ris1 = onion_mult2D(a,c)
ris2 = onion_mult2D(d, torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:, 1:, :, :]),dim=1))
ris3 = onion_mult2D(torch.cat((torch.unsqueeze(a[:,0,:,],1), -a[:, 1:, :, :]), dim=1), d)
ris4 = onion_mult2D(c,b)
aux1 = ris1-ris2
aux2 = ris3+ris4
ris = torch.cat((aux1, aux2), dim=1)
else:
ris = onion1 * onion2
return ris
def onions_quality(im1, im2, size, device):
im1 = im1.type(torch.double).to(device)
im2 = im2.type(torch.double).to(device)
im2 = torch.cat((torch.unsqueeze(im2[:, 0, :, :], 1), -im2[:, 1:, :,:]), dim=1)
batch_size, dim3, _, _ = im1.size()
for bs in range(batch_size):
for i in range(dim3):
a1, s, t = normalize_block(im1[bs,i,:,:])
im1[bs,i,:,:] = a1
if s == 0:
if i == 0:
im2[bs,i,:,:] = im2[bs,i,:,:]-s+1
else:
im2[bs, i, :, :] = -(-im2[bs, i, :, :] - s + 1)
else:
if i == 0:
im2[bs, i, :, :] = ((im2[bs, i, :, :] - s)/t)+1
else:
im2[bs, i, :, :] = -(((-im2[bs, i, :, :] - s)/t)+1)
m1 = torch.mean(im1, dim=(2, 3))
m2 = torch.mean(im2, dim=(2, 3))
mod_q1m = torch.sqrt(torch.sum(m1**2, dim=1))
mod_q2m = torch.sqrt(torch.sum(m2**2, dim=1))
mod_q1 = torch.sqrt(torch.sum(im1 ** 2, dim=1))
mod_q2 = torch.sqrt(torch.sum(im2 ** 2, dim=1))
term2 = mod_q1m * mod_q2m
term4 = mod_q1m**2 + mod_q2m**2
temp = [size ** 2 / (size**2 - 1)] * batch_size
temp = torch.from_numpy(np.asarray(temp)).to(device)
int1 = torch.clone(temp)
int2 = torch.clone(temp)
int3 = torch.clone(temp)
int1 = int1 * torch.mean(mod_q1**2)
int2 = int2 * torch.mean(mod_q2**2)
int3 = int3 * (mod_q1m **2 + mod_q2m ** 2)
term3 = int1 + int2 - int3
mean_bias = 2*term2/term4
if term3 == 0:
q = torch.zeros((batch_size, 1, 1, dim3), device=device, requires_grad=True)
q[:,:,:,dim3-1] = mean_bias
else:
cbm = 2/term3
qu = onion_mult2D(im1, im2)
qm = onion_mult(m1, m2)
#qv = torch.zeros((batch_size, dim3), device=device, requires_grad=True)
#for bs in range(batch_size):
#for i in range(dim3):
#qv[:, i] = (size**2)/(size**2 - 1) * torch.mean(qu[:, i, :,:], dim=(2,3))
qv = (size ** 2) / (size ** 2 - 1) * torch.mean(qu, dim=(-2, -1))
q = qv - temp*qm
q = q*mean_bias*cbm
return q
class complicated_loss (nn.Module):
def __init__(self, device, Q_block_size=32, Q_shift=32):
super(complicated_loss, self).__init__()
self.Q_block_size = Q_block_size
self.Q_shift = Q_shift
self.device = device
def forward(self, outputs, labels):
bs, dim3, dim1, dim2 = labels.size()
_, _, ddim1, ddim2 = outputs.size()
stepx = math.ceil(dim1/self.Q_shift)
stepy = math.ceil(dim2/self.Q_shift)
if stepy <= 0:
stepy = 1
stepx = 1
est1 = (stepx - 1)*self.Q_shift+self.Q_block_size-dim1
est2 = (stepy - 1)*self.Q_shift+self.Q_block_size-dim2
if (est1 != 0)+(est2 != 0) > 0:
padding = torch.nn.ReflectionPad2d((0, est1, 0, est2))
reference = padding(labels)
fused = padding(outputs)
outputs = fused.type(torch.int16).to(self.device)
labels = reference.type(torch.int16).to(self.device)
bs, dim3, dim1, dim2 = labels.size()
if(math.ceil(math.log2(dim3)) - math.log2(dim3) != 0):
exp_difference = 2 ** (torch.ceil(torch.log2(dim3))) - dim3
diff = torch.zeros((bs, exp_difference, dim1, dim2), device=self.device, requires_grad=True).type(torch.int16)
labels = torch.cat((labels, diff), dim=1)
outputs = torch.cat((outputs, diff), dim=1)
bs, dim3, dim1, dim2 = labels.size()
values = torch.zeros((bs, dim3, stepx, stepy), device=self.device, requires_grad=True)
for j in range(stepx):
for i in range(stepy):
o = onions_quality(labels[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], outputs[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], self.Q_block_size, self.device)
values.data[:,:, j,i] = o
index_map = torch.sqrt(torch.sum(values**2, dim=1))
index = torch.mean(index_map)
loss = 1.0 - index
return loss
if __name__ == '__main__':
device = torch.device('cpu')
a = np.arange(256*256)
a = a.reshape(256,256)
a = a.astype('float32')
a = np.expand_dims(a, (0,1))
a = torch.from_numpy(a)
b = torch.zeros(a.size())
a.requires_grad = True
criterion = complicated_loss(device)
loss = criterion(a,b)
f = loss.backward()
The problem is that this loss does not update the weights of the network during the training loop. Where am I wrong?
Thank you!