How to implement an efficient element-wise 'ax+b' function

Hi Zhi and Alex!

The short story is that I can reproduce your results.

I’ve also done some further experimentation.

The best I could come up with was to apply Alex’s torchscript suggestion
to your “w-b” version.

I also tried a Conv2d version – even though it doesn’t have
BatchNorm2d’s normalization step, it still came out slower than BatchNorm2d itself.
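
(In case it helps make the Conv2d version concrete: a 1x1, depthwise
Conv2d – that is, Conv2d (C, C, 1, groups = C) – has one single-scalar
filter per channel, so copying the per-channel scale into its weight and
the per-channel shift into its bias reproduces a*x + b exactly. Here is
a little self-contained toy sketch of that – it is separate from the
benchmark script below, and the sizes and names in it are made up just
for illustration:)

import torch

C = 8                        # toy number of channels
a = torch.randn (C)          # per-channel scale
b = torch.randn (C)          # per-channel shift

# depthwise 1x1 conv -- weight has shape [C, 1, 1, 1], one scalar per channel
cv = torch.nn.Conv2d (C, C, 1, groups = C)
with torch.no_grad():
    _ = cv.weight.copy_ (a.reshape ([C, 1, 1, 1]))
    _ = cv.bias.copy_ (b)

x = torch.randn (2, C, 4, 4)
ycv = cv (x)
yab = x * a.view (1, C, 1, 1) + b.view (1, C, 1, 1)
print ((ycv - yab).abs().max())   # agrees up to round-off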

Note that torchscript gave a nice speedup with pytorch version 1.6.0, but
not with a 1.8.0 nightly build.

I used your original problem – same-shaped tensors, and so on – but
I modified the code to perform some additional tests and to cross-check
the results of the various methods against one another.

Here is the script:

# cxb_time.py

# exec (open ('./cxb_time.py').read())

import torch
print ('torch.__version__ =', torch.__version__)
print ('torch.cuda.get_device_name (0) =', torch.cuda.get_device_name (0))

import time

# contents of torchscript file, jts.py:

# import torch
#
# @torch.jit.script
# def jtwb (w, b, x):
#     return x * w.view (x.shape[0], x.shape[1], 1, 1) + b.view (x.shape[0], x.shape[1], 1, 1)
#
# @torch.jit.script
# def trmac (w, b, x):
#     return ((w * x.transpose (1, 3)).add_ (b)).transpose (1, 3)

from jts import jtwb, trmac

N, C, H, W = 64, 256, 64, 64
repetitions = 100

# random weight and bias
w = torch.randn (C).cuda()
b = torch.randn (C).cuda()

# initialize bn, and cv identically to w-b
bn = torch.nn.BatchNorm2d(C, eps = 0.0).cuda()
cv = torch.nn.Conv2d (C, C, 1, groups = C).cuda()
with torch.no_grad():
    _ = bn.weight.copy_ (w)
    _ = bn.bias.copy_ (b)
    _ = cv.weight.copy_ (w.reshape ([C, 1, 1, 1]))
    _ = cv.bias.copy_ (b)

w0 = w.clone()  # for trmac version
b0 = b.clone()  # for trmac version
w = torch.nn.Parameter(w.expand(N, C).contiguous())
b = torch.nn.Parameter(b.expand(N, C).contiguous())

x = torch.randn(N, C, H, W).cuda()
# normalize x so that only weight and bias (gamma and beta) of batchnorm matter
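# (per-location mean 0 and variance 1 across the batch dimension also give
# per-channel mean 0 and variance 1 over (N, H, W), so bn (x) reduces to
# weight * x + bias)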
x = (x - x.mean (dim = 0)) / x.std (dim = 0, unbiased = False)

# compute y five different ways
ywb = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)
ybn = bn (x)
ycv = cv (x)
yjt = jtwb (w, b, x)
ytr = trmac (w0, b0, x)

# check that results for y are nearly equal
print ('verify that the different methods give the same result:')
print ('(ybn - ywb).abs().max() =', (ybn - ywb).abs().max().item())
print ('(ycv - ywb).abs().max() =', (ycv - ywb).abs().max().item())
print ('(yjt - ywb).abs().max() =', (yjt - ywb).abs().max().item())
print ('(ytr - ywb).abs().max() =', (ytr - ywb).abs().max().item())
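
# drop the references to the verification results (their memory can then be reused)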

ywb = None
ybn = None
ycv = None
yjt = None
ytr = None

print ('time the different methods, repetitions =', repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)

torch.cuda.synchronize()
print ('w-b time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = bn(x)

torch.cuda.synchronize()
print ('bn time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = cv(x)

torch.cuda.synchronize()
print ('cv time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = jtwb (w, b, x)

torch.cuda.synchronize()
print ('jtwb time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = trmac (w0, b0, x)

torch.cuda.synchronize()
print ('jttr time:', (time.time() - t0) / repetitions)
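
(A side note on the timing methodology: I bracketed each loop with
torch.cuda.synchronize() and used time.time(). If you prefer, the same
measurement can be made with cuda Events – here is a sketch for just the
jit “w-b” case, which you could repeat for the others; it is not what
produced the numbers below:)

start = torch.cuda.Event (enable_timing = True)
end = torch.cuda.Event (enable_timing = True)

start.record()
for _ in range (repetitions):
    y = jtwb (w, b, x)

end.record()
torch.cuda.synchronize()
# elapsed_time() is in milliseconds, so divide by 1000 to match the seconds above
print ('jtwb time:', start.elapsed_time (end) / repetitions / 1000.0)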

Here is the pytorch version 1.6.0 output:

>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.6.0
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 2.384185791015625e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 9.5367431640625e-07
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.010975425243377685
bn time: 0.008232767581939698
cv time: 0.009469003677368163
jtwb time: 0.005830504894256592
jttr time: 0.014244294166564942

And, for completeness, here is the 1.8.0.dev20201203 output:

>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.8.0.dev20201203
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 4.76837158203125e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 0.0
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.011010963916778565
bn time: 0.008294436931610107
cv time: 0.009465277194976807
jtwb time: 0.011733412742614746
jttr time: 0.01428614854812622

The only significant difference is that with version 1.8.0.dev20201203,
using jit does not speed up your original “w-b” method. (Also notice
that the 1.6.0 jit “w-b” result differed from the raw “w-b” result at
the level of round-off error, while the 1.8.0 jit “w-b” result exactly
reproduces the raw “w-b” result. It’s as if jit didn’t do anything.)
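
If you want to poke at whether jit is actually fusing the multiply and
add, one thing you could try – I haven’t rerun this on the nightly, so
take it as a sketch of how one might check, rather than as a result – is
to warm the scripted function up a few times and then print the graph
that was actually executed. (If I recall correctly,
torch.jit.last_executed_optimized_graph() is available in both of these
versions.) A fused kernel should show up as a fusion-group node (e.g.
FusionGroup or TensorExprGroup) in that graph:

# warm up the scripted function so the jit has a chance to profile and optimize it
for _ in range (5):
    y = jtwb (w, b, x)

torch.cuda.synchronize()

# if the multiply and add were fused, the executed graph should contain a
# fusion-group node; if not, you'll just see the separate mul / add ops
print (torch.jit.last_executed_optimized_graph())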

Best.

K. Frank