When I add the decode function to my network, the weights don't update. The decode function is as follows:
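```python
import torch
from torch.autograd import Variable

def decoder(x):
    # Split the 448 input features into three groups, copied out as NumPy arrays
    colormax = x[:, 0:192].data.cpu().numpy().reshape((-1, 3, 64))
    colormin = x[:, 192:384].data.cpu().numpy().reshape((-1, 3, 64))
    modulationData = x[:, 384:448].data.cpu().numpy()
    # Build an (N, 3, 1024) output and flatten it to (N, 3072)
    fresult = torch.ones((modulationData.shape[0], 3, 1024))
    fresult = fresult.view((modulationData.shape[0], 3072))
    fresult = Variable(fresult, requires_grad=True)
    return fresult
```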
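The missing update can be reproduced in isolation. The sketch below assumes a single hypothetical `nn.Linear(16, 448)` layer standing in for `fc1` to `fc4`, and mirrors the posted decoder: `.data` and `.cpu().numpy()` copy values out of the autograd graph, and `torch.ones(...)` builds a tensor that does not depend on `x` at all. After `backward()`, the layer's `weight.grad` is still `None`, i.e. no gradient reaches the layers before the decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for fc1..fc4: one layer producing 448 features
fc = nn.Linear(16, 448)

def decoder_as_posted(x):
    # .data / .numpy() copy values out of the autograd graph
    modulationData = x[:, 384:448].data.cpu().numpy()
    n = modulationData.shape[0]
    # A brand-new tensor: its values do not depend on x
    fresult = torch.ones((n, 3, 1024)).view((n, 3072))
    # Marking it as requiring grad makes it a new leaf, not a link back to x
    return fresult.requires_grad_()

inp = torch.randn(4, 16)
out = torch.relu(decoder_as_posted(fc(inp)))
out.sum().backward()

# The loss is differentiable, but only with respect to the new leaf
print(fc.weight.grad is None)  # True: fc receives no gradient
```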
My network's forward pass is:
```python
def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))
    x = F.relu(self.fc4(x))
    x = decoder(x)
    x = F.relu(x)
    return x
```
What's wrong with my decode function?
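For gradients to reach `fc1` to `fc4`, the decoder's output has to be computed from `x` using differentiable torch operations, with no `.data`, no NumPy round-trip, and no freshly created tensor replacing the result. The sketch below keeps the same slicing but stays in torch throughout. The decoding rule itself (a sigmoid-weighted blend of `colormin` and `colormax`, each value repeated 16 times to reach 1024 per channel) is a hypothetical placeholder, since the post does not say how the output should be computed; the `nn.Linear(16, 448)` layer is likewise an assumed stand-in for the network.

```python
import torch
import torch.nn as nn

def decoder_differentiable(x):
    # Slice with tensor ops only, so the result stays connected to x
    colormax = x[:, 0:192].reshape(-1, 3, 64)      # (N, 3, 64)
    colormin = x[:, 192:384].reshape(-1, 3, 64)    # (N, 3, 64)
    modulation = x[:, 384:448]                     # (N, 64)
    # HYPOTHETICAL decoding rule, for illustration only
    weight = torch.sigmoid(modulation).unsqueeze(1)       # (N, 1, 64)
    blended = colormin + (colormax - colormin) * weight   # (N, 3, 64)
    fresult = blended.repeat_interleave(16, dim=2)        # (N, 3, 1024)
    return fresult.reshape(x.shape[0], 3072)              # (N, 3072)

torch.manual_seed(0)
fc = nn.Linear(16, 448)  # assumed stand-in for fc1..fc4
inp = torch.randn(4, 16)
out = torch.relu(decoder_differentiable(fc(inp)))
out.sum().backward()

print(out.shape)                   # torch.Size([4, 3072])
print(fc.weight.grad is not None)  # True: gradient now flows back to fc
```

If the real decoding step genuinely needs NumPy, it would have to be rewritten with torch operations, or wrapped in a custom `torch.autograd.Function` with a hand-written `backward`, before the earlier layers can receive gradients.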