Why is an iterative algorithm using tensors on the GPU slower than the same tensors on the CPU?

Hi, I’m new to this forum and to PyTorch. I’m trying to implement the ADMM algorithm to solve an optimization problem (see the code below). In my application, the tensor sizes are:
b.shape = (16, 1, 256, 1)
w.shape = (16, 1, 1024, 1)
A_down.shape = (256, 1024)
A_up.shape = (1024, 256)
A_AT.shape = (256, 256)
DCT.shape = (1024, 1024)
IDCT.shape = (1024, 1024)
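
For reference, here is a minimal sketch of how tensors with these shapes can be set up (random stand-ins purely for illustration; in my pipeline b and w come from the network and A_down is a fixed measurement matrix):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, channels, n, m = 16, 1, 1024, 256

# Random stand-ins for the real data / operators (illustration only)
b = torch.randn(batch_size, channels, m, 1, device=device)
w = torch.randn(batch_size, channels, n, 1, device=device)
A_down = torch.randn(m, n, device=device)
A_up = A_down.t().contiguous()   # A_up plays the role of A_down's transpose here
A_AT = A_down @ A_down.t()       # m x m, full rank
DCT = None                       # optional sparsifying transform, omitted in this sketch
IDCT = None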

When I run the code below with the tensors on the GPU, it is about twice as slow as when the tensors are on the CPU. When I run the non-batched NumPy version in parallel with joblib, it is at least 5 times faster on a 4-core CPU, e.g.:

import multiprocessing
from joblib import Parallel, delayed

num_cores = multiprocessing.cpu_count()
outs = Parallel(n_jobs=num_cores)(
    delayed(L1L1_Solver)(lr_inputs[i].squeeze(), cnn_outputs[i].squeeze(),
                         beta, skip, AOp, AATOp, SOp)
    for i in range(batch_size)
)
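
A rough way to time the two cases is sketched below, reusing the tensors from the sketch above (assuming a CUDA device is available; beta = 1.0 and skip = 0 are placeholder arguments; torch.cuda.synchronize() is called before reading the clock so asynchronous GPU launches are not miscounted):

import time

def time_solver(b, w, A_down, A_up, A_AT, n_runs=5):
    # One warm-up call (CUDA context creation / kernel compilation)
    L1L1_Solver_TORCH(b, w, 1.0, 0, A_down, A_up, A_AT)
    if b.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        L1L1_Solver_TORCH(b, w, 1.0, 0, A_down, A_up, A_AT)
    if b.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - t0) / n_runs

gpu_time = time_solver(b, w, A_down, A_up, A_AT)
cpu_time = time_solver(b.cpu(), w.cpu(), A_down.cpu(), A_up.cpu(), A_AT.cpu())
print(f"GPU: {gpu_time:.3f} s/run   CPU: {cpu_time:.3f} s/run")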

Am I using PyTorch the wrong way? Any suggestions for accelerating the code below? Thanks!

def L1L1_Solver_TORCH(b: torch.Tensor, w: torch.Tensor, beta: float, skip: int,
                      A_down: torch.Tensor, A_up: torch.Tensor, A_AT: torch.Tensor,
                      DCT: torch.Tensor = None, IDCT: torch.Tensor = None):
    """
    Solves the problem:

            minimize    ||x||_1 + beta*||x-w||_1
                x
            subject to  Ax = b

    or the problem (when Skip Connection):
        minimize    ||x||_1 + beta*||x + w||_1
                x
            subject to  Ax = b - Aw
    
    where b: m x 1,  A: m x n,  w: n x 1, beta > 0 and  m < n.
    Python implementation of L1-L1 solver.
    Details can be found on: https://github.com/joaofcmota/cs-with-prior-information
    """
    B, C, P, Q = b.shape
    B, C, M, N = w.shape
    n = M*N
    m = P*Q

    # Flatten the input: b and w
    b = b.reshape(B, C, m, 1)
    w = w.reshape(B, C, n, 1)

    if DCT is not None:  # Convert w into the sparse domain
        w = DCT @ w

    # For Skip connection across the L1L1-Solver Layer
    if skip == 1:
        w = -1.0 * w
        beta = torch.as_tensor(1/beta, device=b.device)
        b = b + A_down @ w

    # Algorithm Parameters
    rho = torch.as_tensor(1.0, device=b.device)
    tau_rho = torch.as_tensor(10.0, device=b.device)
    mu_rho = torch.as_tensor(2.0, device=b.device)
    eps_prim = torch.as_tensor(1E-3, device=b.device)
    eps_dual = torch.as_tensor(1E-3, device=b.device)
    MAX_ITER = 1000

    # Initialization/preallocation of local variables
    aux1 = torch.zeros_like(b)
    lam = torch.zeros_like(w)
    x = torch.clone(w)
    y = torch.clone(w)
    z = torch.zeros_like(x)
    Az_minus_b = torch.zeros_like(b)
    v = torch.zeros_like(y)
    rhow = torch.zeros_like(w)
    w_pos = torch.zeros(w.shape, dtype=torch.bool, device=w.device)
    not_w_pos = torch.zeros_like(w_pos)
    indices = torch.zeros_like(w_pos)
    zero = torch.as_tensor(0.0, device=b.device)
    one = torch.as_tensor(1.0, device=b.device)

    # Main Iteration
    for k in range(MAX_ITER):
        #======================================== x-minimization ================================#
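        # Closed-form solution of the separable subproblem
        #   min_x |x_i| + beta*|x_i - w_i| + (rho/2)*x_i^2 + v_i*x_i,  with v = lam - rho*y,
        # obtained by case analysis on sign(w_i) and the interval containing v_i.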
        v = lam - rho * y
        rhow = rho * w
        w_pos = (w >= 0.0)

        # Components for which w_i >= 0
        indices = w_pos & (v < -rhow - beta - one)
        x[indices] = (-beta - one - v[indices])/rho

        indices = w_pos & (-rhow - beta - one <= v) & (v <= -rhow + beta - one)
        x[indices] = w[indices]

        indices = w_pos & (-rhow + beta - one < v) & (v < beta - one)
        x[indices] = (beta - one - v[indices])/rho

        indices = w_pos & (beta - one <= v) & (v <= beta + one)
        x[indices] = zero

        indices = w_pos & (v > beta + one)
        x[indices] = (beta + one - v[indices]) / rho

        # Components for which w_i < 0
        not_w_pos = ~w_pos
        indices = not_w_pos & (v < -beta - one)
        x[indices] = (-beta - one - v[indices])/rho

        indices = not_w_pos & (-beta - one <= v) & (v <= -beta + one)
        x[indices] = zero

        indices = not_w_pos & (-beta + one < v) & (v < -rhow - beta + one)
        x[indices] = (-beta + one - v[indices]) / rho

        indices = not_w_pos & (-rhow - beta + one <= v) & (v <= -rhow + beta + one)
        x[indices] = w[indices]

        indices = not_w_pos & (v > -rhow + beta + one)
        x[indices] = (beta + one - v[indices]) / rho
        #=========================== y-minimization  ===========================#
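        # y-update: bring z back onto the constraint set {y : A_down @ y = b}.
        # With A_up acting as A_down's transpose and A_AT = A_down @ A_down.T,
        # this is the exact Euclidean projection of z onto that affine set.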
        y_prev = torch.clone(y)
        z = (lam + (rho*x)) / rho

        Az_minus_b = (A_down @ z) - b
        torch.linalg.solve(A_AT, Az_minus_b, out=aux1)  # A_AT is a full-rank matrix
        y = z - A_up @ aux1
        #===================== Update dual variable ==============================#
        r_prim = x - y  # primal residual

        lam = lam + rho * r_prim

        s_dual = -rho * (y - y_prev)  # dual residual
        #============================= rho parameter adjustment  ======================#
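        # Standard residual-balancing heuristic: grow rho when the primal residual
        # dominates, shrink it when the dual residual dominates.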
        r_prim_norm = torch.linalg.norm(r_prim)
        s_dual_norm = torch.linalg.norm(s_dual)

        if r_prim_norm > tau_rho * s_dual_norm:
            rho = mu_rho*rho
        elif s_dual_norm > tau_rho * r_prim_norm:
            rho = rho/mu_rho

        if (r_prim_norm < eps_prim) and (s_dual_norm < eps_dual):
            break
    #===================================== End of Loop ======================================= #

    x_opt = torch.clone(x)

    if IDCT is not None:  # Convert back to Image Domain
        x_opt = IDCT @ x_opt

    x_opt = x_opt.reshape(B, C, M, N)

    return x_opt