Hi Zhi and Alex!
The short story is that I can reproduce your results.
I’ve also done some further experimentation.
The best I could come up with was to apply Alex’s torchscript suggestion
to your “w-b” version. I also tried a Conv2d version – even though it
doesn’t have BatchNorm2d’s normalization step, it was also slower.
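(For reference, the reason a Conv2d can reproduce the “w-b” computation at
all is that a per-channel affine is exactly a depthwise 1x1 convolution,
i.e., groups = C. A minimal cpu sketch of that equivalence, separate from
the timing script below:)

```python
import torch

# per-channel scale-and-shift expressed as a depthwise 1x1 convolution
N, C, H, W = 2, 4, 5, 5
w = torch.randn(C)
b = torch.randn(C)
x = torch.randn(N, C, H, W)

cv = torch.nn.Conv2d(C, C, kernel_size=1, groups=C)
with torch.no_grad():
    cv.weight.copy_(w.reshape(C, 1, 1, 1))  # one 1x1 kernel per channel
    cv.bias.copy_(b)

y_wb = x * w.view(1, C, 1, 1) + b.view(1, C, 1, 1)
y_cv = cv(x)
print((y_cv - y_wb).abs().max().item())  # round-off-level difference
```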
Note that torchscript gave a nice speedup with pytorch version 1.6.0, but
not with a 1.8.0 nightly build.
I used your original problem – same-shaped tensors, and such – but
I modified the code to perform some additional tests and to tie out the
results of the various methods with one another.
Here is the script:
# cxb_time.py
# exec (open ('./cxb_time.py').read())
import torch
print ('torch.__version__ =', torch.__version__)
print ('torch.cuda.get_device_name (0) =', torch.cuda.get_device_name (0))
import time
# contents of torchscript file, jts.py:
# import torch
#
# @torch.jit.script
# def jtwb (w, b, x):
#     return x * w.view (x.shape[0], x.shape[1], 1, 1) + b.view (x.shape[0], x.shape[1], 1, 1)
#
# @torch.jit.script
# def trmac (w, b, x):
#     return ((w * x.transpose (1, 3)).add_ (b)).transpose (1, 3)
from jts import jtwb, trmac
N, C, H, W = 64, 256, 64, 64
repetitions = 100
# random weight and bias
w = torch.randn (C).cuda()
b = torch.randn (C).cuda()
# initialize bn, and cv identically to w-b
bn = torch.nn.BatchNorm2d(C, eps = 0.0).cuda()
cv = torch.nn.Conv2d (C, C, 1, groups = C).cuda()
with torch.no_grad():
    _ = bn.weight.copy_ (w)
    _ = bn.bias.copy_ (b)
    _ = cv.weight.copy_ (w.reshape ([C, 1, 1, 1]))
    _ = cv.bias.copy_ (b)
w0 = w.clone() # for trmac version
b0 = b.clone() # for trmac version
w = torch.nn.Parameter(w.expand(N, C).contiguous())
b = torch.nn.Parameter(b.expand(N, C).contiguous())
x = torch.randn(N, C, H, W).cuda()
# normalize x so that only weight and bias (gamma and beta) of batchnorm matter
x = (x - x.mean (dim = 0)) / x.std (dim = 0, unbiased = False)
# compute y five different ways
ywb = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)
ybn = bn (x)
ycv = cv (x)
yjt = jtwb (w, b, x)
ytr = trmac (w0, b0, x)
# check that results for y are nearly equal
print ('verify that the different methods give the same result:')
print ('(ybn - ywb).abs().max() =', (ybn - ywb).abs().max().item())
print ('(ycv - ywb).abs().max() =', (ycv - ywb).abs().max().item())
print ('(yjt - ywb).abs().max() =', (yjt - ywb).abs().max().item())
print ('(ytr - ywb).abs().max() =', (ytr - ywb).abs().max().item())
ywb = None
ybn = None
ycv = None
yjt = None
ytr = None
print ('time the different methods, repetitions =', repetitions)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)
torch.cuda.synchronize()
print ('w-b time:', (time.time() - t0) / repetitions)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = bn(x)
torch.cuda.synchronize()
print ('bn time:', (time.time() - t0) / repetitions)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = cv(x)
torch.cuda.synchronize()
print ('cv time:', (time.time() - t0) / repetitions)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = jtwb (w, b, x)
torch.cuda.synchronize()
print ('jtwb time:', (time.time() - t0) / repetitions)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = trmac (w0, b0, x)
torch.cuda.synchronize()
print ('jttr time:', (time.time() - t0) / repetitions)
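(As an aside, instead of bracketing each loop with torch.cuda.synchronize()
and time.time(), gpu kernels can also be timed with cuda events. A sketch
of such a helper – the time_fn name is mine, and it falls back to
time.perf_counter() when no gpu is present:)

```python
import time
import torch

def time_fn(fn, repetitions=10):
    """Average seconds per call, using cuda events when a gpu is present."""
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(repetitions):
            fn()
        end.record()
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds
        return start.elapsed_time(end) / 1000.0 / repetitions
    # cpu fallback
    t0 = time.perf_counter()
    for _ in range(repetitions):
        fn()
    return (time.perf_counter() - t0) / repetitions

x = torch.randn(8, 4, 16, 16)
print('mul time:', time_fn(lambda: x * 2.0))
```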
Here is the pytorch version 1.6.0 output:
>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.6.0
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 2.384185791015625e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 9.5367431640625e-07
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.010975425243377685
bn time: 0.008232767581939698
cv time: 0.009469003677368163
jtwb time: 0.005830504894256592
jttr time: 0.014244294166564942
And, for completeness, here is the 1.8.0.dev20201203 output:
>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.8.0.dev20201203
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 4.76837158203125e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 0.0
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.011010963916778565
bn time: 0.008294436931610107
cv time: 0.009465277194976807
jtwb time: 0.011733412742614746
jttr time: 0.01428614854812622
The only difference is that with version 1.8.0.dev20201203 using jit
does not speed up your original “w-b” method. (Also notice that the
1.6.0 jit “w-b” method differed from the raw “w-b” method at the level
of round-off error, while the 1.8.0 jit “w-b” method exactly reproduced
the raw “w-b” results. It’s as if jit didn’t do anything.)
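(If anyone wants to dig into whether jit is actually doing anything, one
way to probe is to print the scripted function’s graph and look at the ops
it contains. A sketch, using a stand-in affine function of my own naming:)

```python
import torch

@torch.jit.script
def jt_affine(w, b, x):
    # same shape-juggling as the "w-b" method
    return x * w.view(x.shape[0], x.shape[1], 1, 1) + b.view(x.shape[0], x.shape[1], 1, 1)

# the torchscript IR, before any profiling / fusion has run
print(jt_affine.graph)
```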
Best.
K. Frank