Hi, I’m new to this forum and to PyTorch. I’m trying to implement the ADMM algorithm to solve an optimization problem (see the code below). In my application, the sizes of the tensors are:
b.shape = (16, 1, 256, 1)
w.shape = (16, 1, 1024, 1)
A_down.shape = (256, 1024)
A_up.shape = (1024, 256)
A_AT.shape = (256, 256)
DCT.shape = (1024, 1024)
IDCT.shape = (1024, 1024)
When I run the code below with the tensors on the GPU, it is twice as slow as when the tensors are on the CPU. When I run the non-batched NumPy version instead, it is at least 5 times faster on a 4-core CPU, e.g.
# Run the non-batched solver once per batch element, using one worker per CPU core.
# NOTE(review): Parallel/delayed are presumably joblib's — confirm against the imports.
num_cores = multiprocessing.cpu_count()
outs = Parallel(n_jobs=num_cores)(delayed(L1L1_Solver)(lr_inputs[i].squeeze(), cnn_outputs[i].squeeze(), beta, skip, AOp, AATOp, SOp) for i in range(batch_size))
Am I using PyTorch the wrong way? Any suggestions for accelerating the code below? Thanks!
def _l1l1_x_update(x, v, w, w_pos, rho, beta):
    """Element-wise x-minimisation step of the L1-L1 ADMM iteration.

    For every component the minimiser is a piecewise-linear function of
    ``v``; which piece applies depends on the sign of ``w`` and on where
    ``v`` falls relative to thresholds built from ``rho * w`` and ``beta``.

    The pieces are applied with ``torch.where`` in a fixed order, so a later
    piece overrides an earlier one and components matched by no piece (e.g.
    NaN) keep their value from ``x``.

    Args:
        x: Previous iterate; supplies the value of unmatched components.
        v: ``lam - rho * y`` for the current iteration.
        w: Prior signal, same shape as ``x``.
        w_pos: Boolean mask ``w >= 0``.
        rho: Current ADMM penalty parameter (scalar).
        beta: Weight of the ``||x - w||_1`` term (scalar).

    Returns:
        A new tensor holding the updated ``x``; no argument is modified.
    """
    rhow = rho * w
    w_neg = ~w_pos
    zero = torch.zeros((), dtype=x.dtype, device=x.device)

    # The two outermost linear pieces are shared by both sign cases.
    below = (-beta - 1.0 - v) / rho
    above = (beta + 1.0 - v) / rho

    pieces = (
        # Components for which w_i >= 0
        (w_pos & (v < -rhow - beta - 1.0), below),
        (w_pos & (-rhow - beta - 1.0 <= v) & (v <= -rhow + beta - 1.0), w),
        (w_pos & (-rhow + beta - 1.0 < v) & (v < beta - 1.0),
         (beta - 1.0 - v) / rho),
        (w_pos & (beta - 1.0 <= v) & (v <= beta + 1.0), zero),
        (w_pos & (v > beta + 1.0), above),
        # Components for which w_i < 0
        (w_neg & (v < -beta - 1.0), below),
        (w_neg & (-beta - 1.0 <= v) & (v <= -beta + 1.0), zero),
        (w_neg & (-beta + 1.0 < v) & (v < -rhow - beta + 1.0),
         (-beta + 1.0 - v) / rho),
        (w_neg & (-rhow - beta + 1.0 <= v) & (v <= -rhow + beta + 1.0), w),
        (w_neg & (v > -rhow + beta + 1.0), above),
    )
    for mask, value in pieces:
        x = torch.where(mask, value, x)
    return x


def L1L1_Solver_TORCH(b: torch.Tensor, w: torch.Tensor, beta: float, skip: int,
                      A_down: torch.Tensor, A_up: torch.Tensor,
                      A_AT: torch.Tensor, DCT: torch.Tensor = None,
                      IDCT: torch.Tensor = None, *, max_iter: int = 1000,
                      eps_prim: float = 1e-3, eps_dual: float = 1e-3,
                      tau_rho: float = 10.0, mu_rho: float = 2.0):
    """Batched ADMM solver for the L1-L1 minimisation problem.

    Solves the problem:
        minimize    ||x||_1 + beta*||x - w||_1
           x
        subject to  Ax = b
    or the problem (when ``skip == 1``, skip connection):
        minimize    ||x||_1 + beta*||x + w||_1
           x
        subject to  Ax = b - Aw
    where b: m x 1, A: m x n, w: n x 1, beta > 0 and m < n.

    Details can be found on:
    https://github.com/joaofcmota/cs-with-prior-information

    The whole batch shares a single penalty parameter ``rho`` and a single
    stopping test, because the residual norms are taken over all elements.

    Args:
        b: Measurements, 4-D ``(B, C, P, Q)``; flattened to ``m = P * Q``.
        w: Prior signal, 4-D ``(B, C, M, N)``; flattened to ``n = M * N``.
        beta: Positive weight of the prior term.
        skip: ``1`` selects the skip-connection formulation, in which ``w``
            is negated, ``beta`` is replaced by ``1 / beta`` and ``b`` is
            shifted by ``A_down @ w``.
        A_down: Matrix applied to length-``n`` vectors (the ``A`` above).
        A_up: Matrix applied to length-``m`` vectors in the projection step.
        A_AT: Square matrix of the linear system in the projection step;
            must be invertible.
        DCT: Optional transform applied to ``w`` before solving.
        IDCT: Optional transform applied to the solution before returning.
        max_iter: Maximum number of ADMM iterations.
        eps_prim: Stopping tolerance on the primal residual norm.
        eps_dual: Stopping tolerance on the dual residual norm.
        tau_rho: Residual imbalance ratio that triggers a ``rho`` change.
        mu_rho: Factor by which ``rho`` is multiplied or divided.

    Returns:
        The final ``x`` iterate (after ``IDCT`` if given), reshaped to
        ``(B, C, M, N)``.
    """
    _, _, P, Q = b.shape
    B, C, M, N = w.shape
    n = M * N
    m = P * Q

    # Flatten the inputs to batched column vectors.
    b = b.reshape(B, C, m, 1)
    w = w.reshape(B, C, n, 1)

    if DCT is not None:  # Convert w into the sparse domain
        w = DCT @ w

    # Skip connection across the L1L1-solver layer.
    if skip == 1:
        w = -1.0 * w
        beta = 1 / beta
        b = b + A_down @ w

    # A_AT never changes, so the linear system of the projection step is
    # solved once here rather than re-factorised in every iteration:
    #   solve(A_AT, A_down @ z - b) == gram_inv_A @ z - gram_inv_b
    gram_inv_A = torch.linalg.solve(A_AT, A_down)
    gram_inv_b = torch.linalg.solve(A_AT, b)

    # Scalars stay plain Python floats; they broadcast against tensors
    # without having to be created on (or read back from) the device.
    rho = 1.0

    lam = torch.zeros_like(w)
    x = w.clone()
    y = w
    w_pos = w >= 0.0  # loop-invariant sign mask

    for _ in range(max_iter):
        # ---------------------------- x-minimization --------------------
        v = lam - rho * y
        x = _l1l1_x_update(x, v, w, w_pos, rho, beta)

        # ---------------------------- y-minimization --------------------
        # y is rebound (never modified in place), so no copy is required.
        y_prev = y
        z = (lam + rho * x) / rho
        y = z - A_up @ (gram_inv_A @ z - gram_inv_b)

        # ---------------------------- dual update -----------------------
        r_prim = x - y                # primal residual
        lam = lam + rho * r_prim
        s_dual = -rho * (y - y_prev)  # dual residual

        # Fetch both norms with a single transfer to the host.
        r_norm, s_norm = torch.stack(
            (torch.linalg.norm(r_prim), torch.linalg.norm(s_dual))
        ).tolist()

        # ---------------------------- rho adjustment --------------------
        if r_norm > tau_rho * s_norm:
            rho = mu_rho * rho
        elif s_norm > tau_rho * r_norm:
            rho = rho / mu_rho

        if r_norm < eps_prim and s_norm < eps_dual:
            break

    x_opt = x
    if IDCT is not None:  # Convert back to the image domain
        x_opt = IDCT @ x_opt
    return x_opt.reshape(B, C, M, N)