my network weight didn’t update,when I add the decoder function as below: x is the full connect layer output,what’s wrong with my decoder function?
def decoder(x):
colormax = x[:,0:192].view(-1,3,64)
colormin = x[:,192:384].view(-1,3,64)
modulationData = x[:,384:448]
M=torch.cat((modulationData, modulationData, modulationData, modulationData, modulationData, modulationData, modulationData, modulationData,modulationData,
modulationData, modulationData, modulationData, modulationData, modulationData, modulationData, modulationData), 1)
AM=torch.cat((x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256],
x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256],x[:,0:256]),1)
BM = torch.cat((x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256],
x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256]), 1)
fresult = torch.cat((x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256],
x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256], x[:, 0:256]), 1)
M = M.view(-1, 32, 32)
AM=AM.view(-1,3, 32, 32)
BM=BM.view(-1,3, 32, 32)
fresult = fresult.view(-1,3, 32, 32)
blocks = 8
for y in range(blocks):
for x in range(blocks):
mod = modulationData[:,y * blocks + x:y * blocks + x+1]
max = colormax[:, :, y * blocks + x]
min = colormin[:, :, y * blocks + x]
for py in range(4):
for px in range(4):
M[:,py,px]=mod%4
mod=mod/4
# print(AM[:, :, py, px].size())
# print(max.size())
AM[:, :, py, px] = max
BM[:, :, py, px] = min
fresult[:, 0, :, :] = (M * 2 * AM[:, 0, :, :] + (4 - M) * 2 * BM[:, 0, :, :])/128
fresult[:, 1, :, :] = (M * 2 * AM[:, 1, :, :] + (4 - M) * 2 * BM[:, 1, :, :])/128
fresult[:, 2, :, :] = (M * 2 * AM[:, 2, :, :] + (4 - M) * 2 * BM[:, 2, :, :])/128
fresult=fresult.view(-1,3072)
return fresult