I am writing a decode function use pytorch and this function will be used decode the full connect layer output,but when I run this code ,GPU did not work,only CPU work,this is my decode function:
def decoder(x):
colormax = x[:,0:49]
colormin = x[:,49:98]
modulationData = x[:,98:147]
fresult = torch.Tensor(64,784)
if args.cuda:
fresult=fresult.cuda()
# fresult=Variable(fresult)
blocks = 7
blockMask = blocks - 1
WEIGHTS1 = weights[0:4][:]
for y in range(blocks):
for x in range(blocks):
mod = modulationData[:,y * 7 + x:y * 7 + x+1]
inds = 0
for py in range(4):
yOffset = -1 if (py < 2) else 0
y0 = (y + yOffset) % blocks
y1 = (y0 + 1) % blocks
for px in range(4):
xOffset = -1 if (px < 2) else 0
x0 = (x + xOffset) % blocks
x1 = (x0 + 1) % blocks
ca = colormax[:,y0 * 7 + x0:y0 * 7 + x0+1] * factor[inds][0] + colormax[:,y1 * 7 + x0:y1 * 7 + x0+1] * factor[inds][1] + colormax[:,y0 * 7 + x1:y0 * 7 + x1+1] * factor[inds][2] + colormax[:,y1 * 7 + x1:y1 * 7 + x1+1] * factor[inds][3]
cb = colormin[:,y0 * 7 + x0:y0 * 7 + x0+1] * factor[inds][0] + colormin[:,y1 * 7 + x0:y1 * 7 + x0+1] * factor[inds][1] + colormin[:,y0 * 7 + x1:y0 * 7 + x1+1] * factor[inds][2] + colormin[:,y1 * 7 + x1:y1 * 7 + x1+1] * factor[inds][3]
w = np.array([WEIGHTS1[int(i) % 4][:] for i in mod])
w = torch.from_numpy(w).float()
fresult[:,(4 * y + py) * 28 + (4 * x + px)] = (torch.mul(ca.data.cpu(),w[:,0:1]) + torch.mul(cb.data.cpu(),w[:,1:2])) / 128
inds = inds + 1
mod = torch.mul(mod, 0.25)
fresult=Variable(fresult, requires_grad=True)
return fresult
And this is the network:
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.relu(self.fc4(x))
x=decoder(x)
x = F.relu(x)
return x
what is wrong with my code,why the GPU didn't work,only CPU work?
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
data=data.view(-1,784)
optimizer.zero_grad()
output = model(data)
loss = F.mse_loss(output, data)