Hi Tiance and J!

My workstation gpu (and cpu) can accommodate a seq_len of a few
thousand running the loop-free version.
A few more comments about performance: The sync_EMA() version is
really not an apples-to-apples comparison, as it doesn’t reproduce
the computation you originally posted. (If you don’t need your original
exponential-moving-average formulation, there are lots of ways you
might speed things up.)
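For instance (purely as an illustration on my part, not something from your original post), if a truncated exponential kernel is an acceptable approximation, the whole EMA collapses into a single conv1d call. The name ema_conv1d_approx and the window parameter are made up for this sketch:

```python
import torch
import torch.nn.functional as F

def ema_conv1d_approx(x, period: int, window: int = 200):
    # approximate the EMA by truncating the exponential kernel to
    # `window` taps and applying it as a 1-d convolution;
    # `window` is a tuning knob, not part of the original formulation
    w = 2 / (period + 1)
    # kernel[k] = w * (1 - w)**(window - 1 - k), newest sample last
    kernel = w * (1 - w) ** torch.arange(window - 1, -1, -1,
                                         dtype=x.dtype, device=x.device)
    kernel = kernel.view(1, 1, -1)
    xp = F.pad(x, (window - 1, 0))   # left-pad so output length matches input
    return F.conv1d(xp, kernel)
```

Note that this differs from the exact recurrence near the start of the sequence (where the y[0] = x[0] initialization and the truncated kernel matter), so it only makes sense if those edge effects are acceptable for your use case.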
I’ve run some more systematic timings. The precomputed-weight-matrix
version outperforms sync_EMA() on my gpu for seq_len up to 10,000
(about the largest that fits on the gpu) and for seq_len up to 5,000
on the same machine’s cpu.
Here is an expanded version of J’s timing code:
import torch
print (torch.__version__)
import time
_ = torch.manual_seed (2022)
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
def sync_EMA(x, period: int):   # x size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1, device = device)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(start=0, end=period, device = device)
    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)
    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p,), device = device), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)
    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)
def sync_EMA_2(x, period: int):   # x size of (batch, channels, sequence)
    y = torch.zeros_like(x)
    mult = 2 / (period + 1)
    for i in range(x.size()[-1]):
        if i == 0:
            y[:, :, i] = x[:, :, i]
        else:
            y[:, :, i] = x[:, :, i] * mult + y[:, :, i - 1] * (1 - mult)
    return y
def sync_EMA_3(x, period: int):   # x size of (batch, channels, sequence)
    n = x.shape[-1]
    # w = 1 / period
    w = 2 / (period + 1)
    p = (1 - w) ** torch.arange(n + 1, device = device)     # powers of (1 - weight)
    v = p.repeat(n).reshape(n + 1, n)[:-1].triu()           # length x length matrix of powers
    yB = w * x @ v + (1 - w) * x[..., 0, None] @ v[None, 0]
    return yB
periodPrecompute = None
nPrecompute = None
vPrecompute = None
def sync_EMA_3b (x, period: int):
    global periodPrecompute, nPrecompute, vPrecompute
    n = x.shape[-1]
    w = 2 / (period + 1)
    if vPrecompute is None or period != periodPrecompute or n != nPrecompute:
        p = (1 - w) ** torch.arange(n + 1, device = device)
        v = p.repeat(n).reshape(n + 1, n)[:-1].triu()
        vPrecompute = v
        nPrecompute = n
        periodPrecompute = period
    yB = w * x @ vPrecompute + (1 - w) * x[..., 0, None] @ vPrecompute[None, 0]
    return yB
nWarm = 3
nTime = 10
for seq_len in (1000, 2000, 5000, 10000):
    price_rand = torch.randn (seq_len, device = device)
    base = 1000
    seq_sine_period = 30
    y = (base / 10 + price_rand * 20) * torch.sin (torch.arange (seq_len, device = device) / seq_sine_period) + base
    period = 21
    y = y.view (1, 1, -1)
    # warmup
    for i in range (nWarm):
        _ = sync_EMA (y, period)
        _ = sync_EMA_2 (y, period)
        _ = sync_EMA_3 (y, period)
        _ = sync_EMA_3b (y, period)
    # timings
    print ('timings, seq_len:', seq_len)
    print ('device:', device)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime1 = time.time()
    for i in range (nTime):
        EMA1 = sync_EMA (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA:   ', (time.time() - starttime1) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime2 = time.time()
    for i in range (nTime):
        EMA2 = sync_EMA_2 (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_2: ', (time.time() - starttime2) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime3 = time.time()
    for i in range (nTime):
        EMA3 = sync_EMA_3 (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_3: ', (time.time() - starttime3) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime3b = time.time()
    for i in range (nTime):
        EMA3b = sync_EMA_3b (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_3b:', (time.time() - starttime3b) / nTime)
    print ('check EMA1: ', torch.allclose (EMA1, EMA2))
    print ('check EMA3: ', torch.allclose (EMA3, EMA2))
    print ('check EMA3b:', torch.allclose (EMA3b, EMA2))
    print ('check EMA3b, EMA3: ', torch.equal (EMA3b, EMA3))
    print ('EMA1, EMA2 max diff: ', (EMA1 - EMA2).abs().max())
Here are the cpu timings:
1.12.0
timings, seq_len: 1000
device: cpu
sync_EMA: 0.0014473199844360352
sync_EMA_2: 0.05421710014343262
sync_EMA_3: 0.0019309282302856444
sync_EMA_3b: 0.00019500255584716796
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3462)
timings, seq_len: 2000
device: cpu
sync_EMA: 0.0019614458084106444
sync_EMA_2: 0.10763685703277588
sync_EMA_3: 0.005632424354553222
sync_EMA_3b: 0.001221442222595215
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.2375)
timings, seq_len: 5000
device: cpu
sync_EMA: 0.009285163879394532
sync_EMA_2: 0.26809670925140383
sync_EMA_3: 0.039087104797363284
sync_EMA_3b: 0.006497597694396973
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.2563)
timings, seq_len: 10000
device: cpu
sync_EMA: 0.008062648773193359
sync_EMA_2: 0.5337584495544434
sync_EMA_3: 0.14901502132415773
sync_EMA_3b: 0.024390339851379395
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.8373)
And the gpu timings:
1.12.0
timings, seq_len: 1000
device: cuda
sync_EMA: 0.001915717124938965
sync_EMA_2: 0.10515487194061279
sync_EMA_3: 0.00032756328582763674
sync_EMA_3b: 0.00014753341674804686
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.0728, device='cuda:0')
timings, seq_len: 2000
device: cuda
sync_EMA: 0.0024464607238769533
sync_EMA_2: 0.20999035835266114
sync_EMA_3: 0.001055598258972168
sync_EMA_3b: 0.00021555423736572266
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3566, device='cuda:0')
timings, seq_len: 5000
device: cuda
sync_EMA: 0.004022955894470215
sync_EMA_2: 0.5285399198532105
sync_EMA_3: 0.005906414985656738
sync_EMA_3b: 0.0010313272476196289
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3903, device='cuda:0')
timings, seq_len: 10000
device: cuda
sync_EMA: 0.006554818153381348
sync_EMA_2: 1.056663703918457
sync_EMA_3: 0.023943758010864256
sync_EMA_3b: 0.00549619197845459
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.7476, device='cuda:0')
Again, sync_EMA (EMA1) computes something different than your
original version (sync_EMA_2 / EMA2). Whether that matters would
depend on the details of your use case.
The precomputed-weight-matrix version (sync_EMA_3b / EMA3b) both
agrees with your computation and outperforms the other versions for
rather large values of seq_len, in spite of its seq_len**2
inefficiency.
This just underscores the general rule that python loops kill
performance relative to (python) loop-free tensor operations,
especially on a gpu.
(As an aside, if performance remains an issue for you, the seq_len**2
cost can be eliminated with rather more programming effort.)
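To sketch what I mean (this is my own illustration, with made-up names, not code from the thread): the recurrence y[i] = w * x[i] + (1 - w) * y[i - 1] has the closed form y[i] = (1 - w)**i * x[0] + w * sum_j (1 - w)**(i - j) * x[j], which a single cumsum can evaluate in O(seq_len). The catch is that dividing by (1 - w)**i overflows for long sequences, so a production version would have to process the sequence in chunks — that's the extra programming effort:

```python
import torch

def ema_scan(x, period: int):
    # O(seq_len) loop-free EMA via cumulative sums; numerically
    # unstable for long sequences because (1 - w)**-i overflows,
    # so a robust version would apply this chunk by chunk
    w = 2 / (period + 1)
    n = x.shape[-1]
    decay = (1 - w) ** torch.arange(n, dtype=x.dtype, device=x.device)
    head = decay * x[..., 0, None]       # (1 - w)**i * x[0] term
    tail = w * x.clone()
    tail[..., 0] = 0.0                   # j = 0 is handled by head
    # y[i] = (1 - w)**i * sum_{j<=i} w * x[j] / (1 - w)**j + head[i]
    return decay * torch.cumsum(tail / decay, dim=-1) + head
```

For short sequences in double precision this matches the loop version exactly (up to floating-point round-off); for seq_len in the thousands the decay factors underflow and the trick falls apart without chunking.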
Best.
K. Frank