How to implement efficient element-wise 'ax+b' function

For some reason, I want to implement the scaling and bias part of BatchNorm without the normalizing part. I resorted to broadcasting to implement this element-wise operation, but I found that this implementation is slower than the original nn.BatchNorm2d implementation, even though the latter's theoretical FLOPs are several times higher.
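For reference, here is a minimal sketch of just the scale-and-bias step as a module. It assumes per-channel parameters of shape C (rather than the (N, C) expanded copies used in the code below), and the class name ChannelAffine is hypothetical. Viewing each vector as (1, C, 1, 1) lets broadcasting apply it to every sample and every spatial position.

```python
import torch
import torch.nn as nn


class ChannelAffine(nn.Module):
    """Hypothetical module: y[n, c, h, w] = weight[c] * x[n, c, h, w] + bias[c], no normalization."""

    def __init__(self, num_channels: int):
        super().__init__()
        # Same starting values BatchNorm2d uses for its affine parameters
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        c = x.shape[1]
        # (1, C, 1, 1) views broadcast over the batch and spatial dimensions
        return x * self.weight.view(1, c, 1, 1) + self.bias.view(1, c, 1, 1)
```

Usage is the same as any layer, e.g. `ChannelAffine(256)(x)` for an NCHW tensor with 256 channels.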

In [49]: import torch
    ...: import torch.nn as nn
    ...: import time
    ...:
    ...: N, C, H, W = 64, 256, 64, 64
    ...: repetitions = 100
    ...:
    ...: x = torch.randn(N, C, H, W).cuda()
    ...: w = nn.Parameter(torch.randn(C).cuda().expand(N, C).contiguous())
    ...: b = nn.Parameter(torch.randn(C).cuda().expand(N, C).contiguous())
    ...:
    ...: t0 = time.time()
    ...: for _ in range(repetitions):
    ...:     y = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)
    ...: torch.cuda.synchronize()
    ...: print ((time.time() - t0) / repetitions)
0.0012868165969848633
In [41]: import torch
    ...: import time
    ...:
    ...: N, C, H, W = 64, 256, 64, 64
    ...: repetitions = 100
    ...:
    ...: x = torch.randn(N, C, H, W).cuda()
    ...: #w = torch.randn(C).cuda().expand(N, C).contiguous()
    ...: bn = torch.nn.BatchNorm2d(C).cuda()
    ...:
    ...: t0 = time.time()
    ...: for _ in range(repetitions):
    ...:     y = bn(x)
    ...: torch.cuda.synchronize()
    ...: print ((time.time() - t0) / repetitions)
0.0009896183013916016
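One caveat about these two measurements: neither loop synchronizes the GPU before starting the clock, so queued work from the preceding `.cuda()` calls can be billed to the loop, and first-call costs are included. A minimal harness sketch that warms up and synchronizes on both sides is below; `time_per_call` is a hypothetical helper, and the shapes are shrunk so it also runs on a CPU-only machine.

```python
import time

import torch


def time_per_call(fn, repetitions=100, warmup=3):
    """Average wall-clock seconds per call of fn().

    CUDA kernels launch asynchronously, so the device is synchronized
    before starting the clock and again before stopping it; otherwise
    queued work from earlier statements is billed to the loop.
    The warm-up calls absorb one-time costs.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(repetitions):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / repetitions


device = "cuda" if torch.cuda.is_available() else "cpu"
N, C, H, W = 8, 16, 32, 32  # small shapes so the sketch also runs on CPU
x = torch.randn(N, C, H, W, device=device)
w = torch.randn(1, C, 1, 1, device=device)
b = torch.randn(1, C, 1, 1, device=device)
bn = torch.nn.BatchNorm2d(C).to(device)

t_wb = time_per_call(lambda: x * w + b)
t_bn = time_per_call(lambda: bn(x))
print(f"w-b: {t_wb:.6f} s   bn: {t_bn:.6f} s")
```

In PyTorch 1.8 and later, `torch.utils.benchmark.Timer` takes care of these details as well.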

So what is the best way to implement the scaling and bias function? In the training phase, the time per batch with the broadcasting implementation is almost twice that of the original BN implementation in my model, which uses a ResNet18 backbone.
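In training, the backward pass adds reductions as well. With y = w_c * x + b_c for channel c and g = dL/dy, the parameter gradients are dL/dw_c = sum over (n, h, w) of g * x and dL/db_c = sum over (n, h, w) of g, while dL/dx = w_c * g. (With the (N, C)-shaped parameters in the question, the sums run over H and W only.) A small CPU sketch that checks these formulas against autograd, assuming per-channel parameters of shape C:

```python
import torch

torch.manual_seed(0)
N, C, H, W = 4, 3, 5, 5
x = torch.randn(N, C, H, W, requires_grad=True)
w = torch.randn(C, requires_grad=True)
b = torch.randn(C, requires_grad=True)

y = x * w.view(1, C, 1, 1) + b.view(1, C, 1, 1)
g = torch.randn_like(y)  # stand-in for the upstream gradient dL/dy
y.backward(g)

# Manual gradients: parameter grads reduce over batch (0) and spatial (2, 3) dims
grad_w = (g * x.detach()).sum(dim=(0, 2, 3))
grad_b = g.sum(dim=(0, 2, 3))
# The input gradient is another broadcast multiply, with no reduction
grad_x = g * w.detach().view(1, C, 1, 1)
```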

I am working with PyTorch 1.8.0a0+79b9c03 compiled with CUDA 11.1 and cuDNN 8005.

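Try these:

from torch import jit

# Option 1: TorchScript may fuse the multiply and add into one kernel
@jit.script
def affine(x, m, a):
  return x*m+a

# Option 2: add in place to skip allocating a second temporary
def affine_inplace(x, m, a):
  return (x*m).add_(a)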
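Another candidate worth trying is `torch.addcmul(input, tensor1, tensor2)`, which computes `input + tensor1 * tensor2` in a single op with broadcasting. It is not measured anywhere in this thread, so treat it as a sketch to benchmark rather than a known win; `affine_addcmul` is just an illustrative name.

```python
import torch


def affine_addcmul(x, m, a):
    # Computes a + x * m in one call; m and a broadcast against x
    return torch.addcmul(a, x, m)


x = torch.randn(2, 4, 8, 8)
m = torch.randn(1, 4, 1, 1)  # per-channel scale, shaped for broadcasting
a = torch.randn(1, 4, 1, 1)  # per-channel bias
y = affine_addcmul(x, m, a)
```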

Also, the overhead of the view() calls in your test may be non-negligible.
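`view()` does not copy data: it returns a new tensor object that shares storage with the original and only changes shape/stride metadata. Its cost is therefore a fixed per-call overhead rather than memory traffic, which is why it can show up in a loop of short kernels. A minimal sketch showing the shared storage and the views hoisted out of the loop (eager mode, no autograd involved):

```python
import torch

N, C = 4, 8
w = torch.randn(N, C)

w4 = w.view(N, C, 1, 1)  # no copy: same storage, new shape/stride metadata
print(w4.data_ptr() == w.data_ptr())  # True

# Create the views once, outside the hot loop, so each iteration only launches the kernels
x = torch.randn(N, C, 16, 16)
b4 = torch.randn(N, C).view(N, C, 1, 1)
for _ in range(10):
    y = x * w4 + b4
```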

Hi Zhi and Alex!

The short story is that I can reproduce your results.

I’ve also done some further experimentation.

The best I could come up with was to apply Alex’s torchscript suggestion
to your “w-b” version.

I also tried a Conv2d version (a depthwise 1x1 convolution): even though it doesn’t have
BatchNorm2d's normalization step, it was also slower than BatchNorm2d.

Note that TorchScript gave a nice speedup with PyTorch 1.6.0, but
not with a 1.8.0 nightly build.

I used your original problem – same-shaped tensors, and such – but
I modified the code to perform some additional tests and to tie out the
results of the various methods with one another.

Here is the script:

# cxb_time.py

# exec (open ('./cxb_time.py').read())

import torch
print ('torch.__version__ =', torch.__version__)
print ('torch.cuda.get_device_name (0) =', torch.cuda.get_device_name (0))

import time

# contents of torchscript file, jts.py:

# import torch
#
# @torch.jit.script
# def jtwb (w, b, x):
#     return x * w.view (x.shape[0], x.shape[1], 1, 1) + b.view (x.shape[0], x.shape[1], 1, 1)
#
# @torch.jit.script
# def trmac (w, b, x):
#     return ((w * x.transpose (1, 3)).add_ (b)).transpose (1, 3)

from jts import jtwb, trmac

N, C, H, W = 64, 256, 64, 64
repetitions = 100

# random weight and bias
w = torch.randn (C).cuda()
b = torch.randn (C).cuda()

# initialize bn, and cv identically to w-b
bn = torch.nn.BatchNorm2d(C, eps = 0.0).cuda()
cv = torch.nn.Conv2d (C, C, 1, groups = C).cuda()
with torch.no_grad():
    _ = bn.weight.copy_ (w)
    _ = bn.bias.copy_ (b)
    _ = cv.weight.copy_ (w.reshape ([C, 1, 1, 1]))
    _ = cv.bias.copy_ (b)

w0 = w.clone()  # for trmac version
b0 = b.clone()  # for trmac version
w = torch.nn.Parameter(w.expand(N, C).contiguous())
b = torch.nn.Parameter(b.expand(N, C).contiguous())

x = torch.randn(N, C, H, W).cuda()
# normalize x so that only weight and bias (gamma and beta) of batchnorm matter
x = (x - x.mean (dim = 0)) / x.std (dim = 0, unbiased = False)

# compute y five different ways
ywb = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)
ybn = bn (x)
ycv = cv (x)
yjt = jtwb (w, b, x)
ytr = trmac (w0, b0, x)

# check that results for y are nearly equal
print ('verify that the different methods give the same result:')
print ('(ybn - ywb).abs().max() =', (ybn - ywb).abs().max().item())
print ('(ycv - ywb).abs().max() =', (ycv - ywb).abs().max().item())
print ('(yjt - ywb).abs().max() =', (yjt - ywb).abs().max().item())
print ('(ytr - ywb).abs().max() =', (ytr - ywb).abs().max().item())

ywb = None
ybn = None
ycv = None
yjt = None
ytr = None

print ('time the different methods, repetitions =', repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = x * w.view(N, C, 1, 1) + b.view(N, C, 1, 1)

torch.cuda.synchronize()
print ('w-b time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = bn(x)

torch.cuda.synchronize()
print ('bn time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = cv(x)

torch.cuda.synchronize()
print ('cv time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = jtwb (w, b, x)

torch.cuda.synchronize()
print ('jtwb time:', (time.time() - t0) / repetitions)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(repetitions):
    y = trmac (w0, b0, x)

torch.cuda.synchronize()
print ('jttr time:', (time.time() - t0) / repetitions)
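The line that normalizes `x` over `dim = 0` is what lets `BatchNorm2d` (training mode, `eps = 0.0`) reduce to a pure scale and bias. After it, every (c, h, w) position has zero mean and unit biased variance across the batch, so each channel's statistics over (N, H, W) are mean 0 and variance 1 up to round-off. A small CPU sketch checking that claim, with smaller shapes than the script:

```python
import torch

torch.manual_seed(0)
N, C, H, W = 16, 4, 8, 8
x = torch.randn(N, C, H, W)
x = (x - x.mean(dim=0)) / x.std(dim=0, unbiased=False)

# Per-channel statistics over (N, H, W), which BatchNorm2d uses in training mode
mean_c = x.mean(dim=(0, 2, 3))
var_c = x.var(dim=(0, 2, 3), unbiased=False)

# Consequently BatchNorm2d with eps=0 reduces to weight * x + bias
w = torch.randn(C)
b = torch.randn(C)
bn = torch.nn.BatchNorm2d(C, eps=0.0)
with torch.no_grad():
    bn.weight.copy_(w)
    bn.bias.copy_(b)
    y_bn = bn(x)
y_wb = x * w.view(1, C, 1, 1) + b.view(1, C, 1, 1)
```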

Here is the pytorch version 1.6.0 output:

>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.6.0
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 2.384185791015625e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 9.5367431640625e-07
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.010975425243377685
bn time: 0.008232767581939698
cv time: 0.009469003677368163
jtwb time: 0.005830504894256592
jttr time: 0.014244294166564942

And, for completeness, here is the 1.8.0.dev20201203 output:

>>> exec (open ('./cxb_time.py').read())
torch.__version__ = 1.8.0.dev20201203
torch.cuda.get_device_name (0) = GeForce GTX 1050 Ti
verify that the different methods give the same result:
(ybn - ywb).abs().max() = 4.76837158203125e-07
(ycv - ywb).abs().max() = 9.5367431640625e-07
(yjt - ywb).abs().max() = 0.0
(ytr - ywb).abs().max() = 0.0
time the different methods, repetitions = 100
w-b time: 0.011010963916778565
bn time: 0.008294436931610107
cv time: 0.009465277194976807
jtwb time: 0.011733412742614746
jttr time: 0.01428614854812622

The only difference is that with version 1.8.0.dev20201203 using jit
does not speed up your original “w-b” method. (Also notice that the
1.6.0 jit “w-b” method differed at round-off error from the raw “w-b”
method, while the 1.8.0 jit “w-b” method exactly reproduced the raw
“w-b” results. It’s as if jit didn’t do anything.)

Best.

K. Frank

With the 1.7+ JIT you need to do two warm-up iterations (one for profiling, one for code generation).

Hi Alex!

I added the following to my timing script:

# try "warming up" jtwb
print ('call jtwb() twice to "warm up"')
yjt1 = jtwb (w, b, x)
yjt2 = jtwb (w, b, x)
yjt1 = None
yjt2 = None

(So there are now three calls to jtwb() before the actual timing loop.)

This changed the results for the jtwb timing only modestly:

jtwb time: 0.010964112281799316
(was:      0.011733412742614746)

And yjt (the result of the jtwb() call) still agrees exactly with ywb.

Best.

K. Frank

OK, it appears that the fusion optimization is fragile:

  1. “from jts import *” breaks it, while “import jts” works
  2. creating views inside the jitted function breaks it
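Following point 2, a sketch that keeps the scripted function to pure element-wise math and does the reshaping outside it. It assumes the parameters can be stored (or viewed once) in the broadcastable (1, C, 1, 1) shape, and `scaled_bias` is a hypothetical name. Whether the fuser actually picks it up still depends on the PyTorch version and device.

```python
import torch


@torch.jit.script
def scaled_bias(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Only element-wise ops here: no view()/reshape() inside the scripted function
    return x * w + b


C = 8
w = torch.randn(C).view(1, C, 1, 1)  # reshape once, outside the scripted code
b = torch.randn(C).view(1, C, 1, 1)
x = torch.randn(4, C, 16, 16)

for _ in range(3):  # a few calls so the profiling executor can warm up
    y = scaled_bias(x, w, b)
```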

You can check this with torch.jit.last_executed_optimized_graph(): after warm-up, this graph should contain a TensorExprGroup block (or a FusionGroup block with the legacy JIT).
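A minimal sketch of that check: warm the scripted function up, then search the text of the last optimized graph for a fusion block. Whether a fused group appears depends on the PyTorch version, the device, and which fuser is enabled, so the sketch only reports what it finds.

```python
import torch


@torch.jit.script
def affine(x: torch.Tensor, m: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    return x * m + a


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, 16, 16, device=device)
m = torch.randn(1, 8, 1, 1, device=device)
a = torch.randn(1, 8, 1, 1, device=device)

for _ in range(3):  # warm-up: profiling run(s) first, then the optimized run
    affine(x, m, a)

graph_text = str(torch.jit.last_executed_optimized_graph())
fused = "TensorExprGroup" in graph_text or "FusionGroup" in graph_text
print("fusion block found:", fused)
```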

For the moment, it may be better to use the legacy fuser:

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)

Hi Alex!
I think the fusion optimization of torch.jit only works on inputs with the fixed size it was warmed up with (in fact, I found that feeding tensors of another size doesn’t work). Is my understanding right?
If I want to apply the optimized function to tensors of different sizes (e.g., the number of channels in the feature maps increases as you go deeper in a CNN), is there an elegant way to do it?

This puzzles me a bit too: the new codegen creates over-specialized kernels that hardcode tensor sizes. However, my understanding is that it will compile multiple times for different sizes if it sees the same shape[s] twice. The torch._C._jit_set_bailout_depth(n) option controls this; it is 20 by default.
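If the set of shapes is known ahead of time, one hedged option is to warm the scripted function up on each of them before timing or training, so each shape gets a chance to be specialized within the bailout budget. The shapes below are illustrative (roughly following ResNet-18's 64/128/256/512-channel stages), and whether specialization actually happens depends on the executor in your PyTorch version.

```python
import torch


@torch.jit.script
def affine(x: torch.Tensor, m: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    return x * m + a


device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative (N, C, H, W) shapes, roughly following ResNet-18's stages
shapes = [(8, 64, 56, 56), (8, 128, 28, 28), (8, 256, 14, 14), (8, 512, 7, 7)]

# Warm up each shape twice (profiling run + optimized run) before real use
for n, c, h, w in shapes:
    x = torch.randn(n, c, h, w, device=device)
    m = torch.randn(1, c, 1, 1, device=device)
    a = torch.randn(1, c, 1, 1, device=device)
    for _ in range(2):
        y = affine(x, m, a)
```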

A simpler alternative is to use the legacy JIT, which produces more generic code.