Parallelize & accelerate loops of tensor additions

Background:

I am working on a program that first shifts the channels of a tensor along the “column” (width) dimension by different distances, and then sums along the “channel” dimension to merge the channels into one. Specifically, given a tensor x of size (B, C, H, W) and a step size S, where B, C, H, W denote the batch size, number of channels, height, and width, respectively, the i-th channel of x is shifted by distance (i-1)*S, and the C channels are then summed into a single channel of width W + (C-1)*S.

Here is a 1D toy example.
Assume that I have a 3-channel tensor x as

x = torch.tensor(
    [[1, 1, 1],
     [2, 2, 2],
     [3, 3, 3]]
)

Now I set the step size to 1 and shift the tensor as follows:

x_shifted = torch.tensor(
    [[1, 1, 1, 0, 0],
     [0, 2, 2, 2, 0],
     [0, 0, 3, 3, 3]]
)

Here, the first channel is shifted by distance 0, the second channel is shifted by distance 1, and the third channel is shifted by distance 2.
Finally, all three channels are summed and merged into one:

y = torch.tensor(
    [[1, 3, 6, 5, 3]]
)
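
For reference, the toy example can be reproduced with a simple padded loop (a minimal sketch, just to make the operation concrete):

import torch

x = torch.tensor(
    [[1, 1, 1],
     [2, 2, 2],
     [3, 3, 3]]
)
C, W = x.shape
S = 1

# shift channel i to the right by i*S columns, then sum the channels
y = torch.zeros(W + (C - 1) * S)
for i in range(C):
    y[i * S : i * S + W] += x[i]
print(y)  # tensor([1., 3., 6., 5., 3.])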

Question:

I have implemented this process for 2D image tensors in the following code:

import torch
import torch.nn.functional as F
from time import time

#############################################
# Parameters
#############################################

B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')

seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

#############################################
# Method 1
#############################################

# alpha[0, 0, 0, j] counts how many shifted channels cover output column j
# (used for normalization in the pseudo-inverse)
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1

def A(x, mask):
    # shift channel i of (x * mask) right by i*S columns and accumulate into one channel
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y

def A_pinv(y, mask):
    # approximate pseudo-inverse: normalize by alpha, undo the shifts, then undo the mask
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x

#############################################
# Method 2
#############################################

# one-hot kernel per input channel; the tap positions (together with the padding)
# make conv2d shift channel i right by i*S before summing over the channels
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1

def A_fast(x, mask):
    # same shift + summation as A, expressed as a single conv2d
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))

def A_pinv_fast(y, mask):
    # transposed convolution undoes the shifts; division by alpha and mask matches A_pinv
    return F.conv_transpose2d(y / alpha.to(y.device), kernel.to(y.device), padding=(0, (C-1)*S)) / mask

#############################################
# Test 1
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
    x_init = A_pinv(y, mask)
    y_init = A(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)

#############################################
# Test 2
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
    x_init = A_pinv_fast(y, mask)
    y_init = A_fast(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)

Here, Method 1 implements the process with for loops, while I believe Method 2 implements the same process equivalently using a 2D convolution (and its transpose for the pseudo-inverse).

To be more specific, the functions A and A_pinv realize the forward compression process and its “pseudo-inverse”, respectively. Their “fast” versions in Method 2 are expected to be faster thanks to the parallelized, convolution-based implementation.
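
As a quick sanity check of the equivalence (not part of the benchmark), the two methods can be compared directly on a random input, using the definitions above:

# quick check that A and A_fast produce the same output
x = torch.rand(B, C, H, W, device=device)
mask = torch.rand(1, 1, H, W, device=device)
mask[mask == 0] = 1e-12
diff = (A(x, mask) - A_fast(x, mask)).abs().max()
print(diff)  # expected to be tiny (float rounding only) if the two methods are equivalent

The pseudo-inverse pair can be compared the same way, although the division by mask amplifies absolute differences, so a relative error is more informative there.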

However, when I run the code, I find that Method 1 is still much faster than Method 2, by a large margin. I would like to ask:

Can we effectively accelerate Method 1? More specifically, can the for loops be parallelized to make the “Shift + Summation” process faster?

I am still thinking and testing…
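
One direction I have been sketching (not benchmarked, and possibly not actually faster) is to replace the Python loop in A with a single scatter_add_ call: precompute, for every input element, the output column it lands in, and let scatter_add_ do the accumulation. A rough sketch with a hypothetical helper A_scatter, using the same global shapes as above:

def A_scatter(x, mask):
    # hypothetical loop-free variant of A: element (i, j) of each channel is written
    # to output column i*S + j, and the accumulation is done by scatter_add_
    z = (x * mask).permute(0, 2, 1, 3).reshape(B, H, C * W)     # z[b, h, i*W + j] == (x*mask)[b, i, h, j]
    idx = (torch.arange(C, device=x.device) * S).view(C, 1) \
          + torch.arange(W, device=x.device).view(1, W)         # (C, W): target column of each element
    idx = idx.reshape(1, 1, C * W).expand(B, H, C * W)
    y = torch.zeros(B, H, W + (C - 1) * S, device=x.device)
    y.scatter_add_(2, idx, z)
    return y.unsqueeze(1)                                       # (B, 1, H, W + (C-1)*S)

I have not yet checked whether this actually beats the simple loop; verifying its output against A on small inputs would be the first step.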


Thank you for reading my long post :slight_smile: