# Why is an iterative algorithm using tensors on the GPU slower than on the CPU?

Hi, I’m new to this forum and to PyTorch. I’m trying to implement the ADMM algorithm to solve an optimization problem (see the code below). In my application, the sizes of the tensors are:
b.shape = (16, 1, 256, 1)
w.shape = (16, 1, 1024, 1)
A_down.shape = (256, 1024)
A_up.shape = (1024, 256)
A_AT.shape = (256, 256)
DCT.shape = (1024, 1024)
IDCT.shape = (1024, 1024)

When I run the code below with the tensors on the GPU, it is twice as slow as when the tensors are on the CPU. When I run the non-batched NumPy version in parallel on a 4-core CPU, it is at least 5 times faster, for example:

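```python
import multiprocessing
from joblib import Parallel, delayed  # Parallel/delayed are assumed to come from joblib
num_cores = multiprocessing.cpu_count()
outs = Parallel(n_jobs=num_cores)(delayed(L1L1_Solver)(lr_inputs[i].squeeze(), cnn_outputs[i].squeeze(), beta, skip, AOp, AATOp, SOp) for i in range(batch_size))
```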
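For reference, here is a minimal sketch of how the GPU and CPU timings could be compared fairly. CUDA kernels are launched asynchronously, so the clock should only be stopped once the device has finished its queued work. The helper `benchmark` and its parameters (`sync`, `warmup`, `repeats`) are hypothetical names for illustration and are not part of my solver code.

```python
import statistics
import time


def benchmark(fn, sync=None, warmup=3, repeats=10):
    """Return the median wall-clock time of fn() in seconds.

    sync is an optional zero-argument callable that blocks until all
    queued device work has finished (assumption: for CUDA tensors this
    would be torch.cuda.synchronize).
    """
    for _ in range(warmup):  # warm-up runs absorb one-time setup costs
        fn()
    if sync is not None:
        sync()  # make sure warm-up work is done before timing starts

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the device before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For the GPU run this would be called as `benchmark(lambda: L1L1_Solver_TORCH(...), sync=torch.cuda.synchronize)`, and for the CPU run without the `sync` argument.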

Am I using PyTorch the wrong way? Any suggestions to accelerate the code below? Thanks!

```python
from typing import Optional

import torch


def L1L1_Solver_TORCH(
    b: torch.Tensor,
    w: torch.Tensor,
    beta: float,
    skip: int,
    A_down: torch.Tensor,
    A_up: torch.Tensor,
    A_AT: torch.Tensor,
    DCT: Optional[torch.Tensor] = None,
    IDCT: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Solves the problem:

        minimize    ||x||_1 + beta*||x - w||_1
           x
        subject to  Ax = b

    or the problem (when Skip Connection):

        minimize    ||x||_1 + beta*||x + w||_1
           x
        subject to  Ax = b - Aw

    where b: m x 1,  A: m x n,  w: n x 1, beta > 0 and  m < n.
    Python implementation of L1-L1 solver.
    Details can be found on: https://github.com/joaofcmota/cs-with-prior-information
    """
    B, C, P, Q = b.shape
    B, C, M, N = w.shape
    n = M * N
    m = P * Q

    # Flatten the inputs b and w into batched column vectors
    b = b.reshape(B, C, m, 1)
    w = w.reshape(B, C, n, 1)

    if DCT is not None:  # Convert w into the sparse domain
        w = DCT @ w

    # For the skip connection across the L1L1-Solver layer
    if skip == 1:
        w = -1.0 * w
        beta = torch.as_tensor(1 / beta, device=b.device)
        b = b + A_down @ w

    # Algorithm parameters
    rho = torch.as_tensor(1.0, device=b.device)
    tau_rho = torch.as_tensor(10.0, device=b.device)
    mu_rho = torch.as_tensor(2.0, device=b.device)
    eps_prim = torch.as_tensor(1e-3, device=b.device)
    eps_dual = torch.as_tensor(1e-3, device=b.device)
    MAX_ITER = 1000

    # Initialization / preallocation of local variables
    aux1 = torch.zeros_like(b)
    lam = torch.zeros_like(w)
    x = torch.clone(w)
    y = torch.clone(w)
    z = torch.zeros_like(x)
    Az_minus_b = torch.zeros_like(b)
    v = torch.zeros_like(y)
    rhow = torch.zeros_like(w)
    w_pos = torch.zeros(w.shape, dtype=torch.bool, device=w.device)
    not_w_pos = torch.zeros_like(w_pos)
    indices = torch.zeros_like(w_pos)
    zero = torch.as_tensor(0.0, device=b.device)
    one = torch.as_tensor(1.0, device=b.device)

    # Main iteration
    for k in range(MAX_ITER):
        # ======================== x-minimization ======================== #
        v = lam - rho * y
        rhow = rho * w
        w_pos = (w >= 0.0)

        # Components for which w_i >= 0
        indices = w_pos & (v < -rhow - beta - one)
        x[indices] = (-beta - one - v[indices]) / rho

        indices = w_pos & (-rhow - beta - one <= v) & (v <= -rhow + beta - one)
        x[indices] = w[indices]

        indices = w_pos & (-rhow + beta - one < v) & (v < beta - one)
        x[indices] = (beta - one - v[indices]) / rho

        indices = w_pos & (beta - one <= v) & (v <= beta + one)
        x[indices] = zero

        indices = w_pos & (v > beta + one)
        x[indices] = (beta + one - v[indices]) / rho

        # Components for which w_i < 0
        not_w_pos = ~w_pos
        indices = not_w_pos & (v < -beta - one)
        x[indices] = (-beta - one - v[indices]) / rho

        indices = not_w_pos & (-beta - one <= v) & (v <= -beta + one)
        x[indices] = zero

        indices = not_w_pos & (-beta + one < v) & (v < -rhow - beta + one)
        x[indices] = (-beta + one - v[indices]) / rho

        indices = not_w_pos & (-rhow - beta + one <= v) & (v <= -rhow + beta + one)
        x[indices] = w[indices]

        indices = not_w_pos & (v > -rhow + beta + one)
        x[indices] = (beta + one - v[indices]) / rho

        # ======================== y-minimization ======================== #
        y_prev = torch.clone(y)
        z = (lam + (rho * x)) / rho

        Az_minus_b = (A_down @ z) - b
        torch.linalg.solve(A_AT, Az_minus_b, out=aux1)  # A_AT is a full-rank matrix
        y = z - A_up @ aux1

        # ===================== Update dual variable ===================== #
        r_prim = x - y  # primal residual

        lam = lam + rho * r_prim

        s_dual = -rho * (y - y_prev)  # dual residual
        r_prim_norm = torch.linalg.norm(r_prim)
        s_dual_norm = torch.linalg.norm(s_dual)

        # Adapt rho when one residual is much larger than the other
        if r_prim_norm > tau_rho * s_dual_norm:
            rho = mu_rho * rho
        elif s_dual_norm > tau_rho * r_prim_norm:
            rho = rho / mu_rho

        # Stop once both residuals are below their tolerances
        if (r_prim_norm < eps_prim) and (s_dual_norm < eps_dual):
            break
    # ========================== End of Loop ============================= #

    x_opt = torch.clone(x)

    if IDCT is not None:  # Convert back to the image domain
        x_opt = IDCT @ x_opt

    x_opt = x_opt.reshape(B, C, M, N)

    return x_opt
```