Hi Tiance and J!

My workstation gpu (and cpu) can accommodate a `seq_len` of a few thousand running the loop-free version.

A few more comments about performance: The `sync_EMA()` version is really not an apples-to-apples comparison, as it doesn't reproduce the computation you originally posted. (If you don't need your original exponential-moving-average formulation, there are lots of ways you might speed things up.)

I've run some more systematic timings. The precomputed-weight-matrix version outperforms `sync_EMA()` on my gpu for `seq_len` up to 10,000 (about the largest that fits on the gpu) and for `seq_len` up to 5,000 on the same machine's cpu.

Here is an expanded version of J’s timing code:

```
import torch
print (torch.__version__)
import time

_ = torch.manual_seed (2022)

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

def sync_EMA(x, period: int):   # x size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1, device = device)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(start=0, end=period, device = device)
    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)
    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p,), device = device), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)
    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)

def sync_EMA_2(x, period: int):   # x size of (batch, channels, sequence)
    y = torch.zeros_like(x)
    mult = 2 / (period + 1)
    for i in range(x.size()[-1]):
        if i == 0:
            y[:, :, i] = x[:, :, i]
        else:
            y[:, :, i] = x[:, :, i] * mult + y[:, :, i - 1] * (1 - mult)
    return y

def sync_EMA_3(x, period: int):   # x size of (batch, channels, sequence)
    n = x.shape[-1]
    # w = 1 / period
    w = 2 / (period + 1)
    p = (1 - w) ** torch.arange(n + 1, device = device)   # powers of (1 - weight)
    v = p.repeat(n).reshape(n + 1, n)[:-1].triu()         # length x length matrix of powers
    yB = w * x @ v + (1 - w) * x[..., 0, None] @ v[None, 0]
    return yB

periodPrecompute = None
nPrecompute = None
vPrecompute = None

def sync_EMA_3b (x, period: int):
    global periodPrecompute, nPrecompute, vPrecompute
    n = x.shape[-1]
    w = 2 / (period + 1)
    if vPrecompute is None or period != periodPrecompute or n != nPrecompute:
        p = (1 - w) ** torch.arange(n + 1, device = device)
        v = p.repeat(n).reshape(n + 1, n)[:-1].triu()
        vPrecompute = v
        nPrecompute = n
        periodPrecompute = period
    yB = w * x @ vPrecompute + (1 - w) * x[..., 0, None] @ vPrecompute[None, 0]
    return yB

nWarm = 3
nTime = 10

for seq_len in (1000, 2000, 5000, 10000):
    price_rand = torch.randn (seq_len, device = device)
    base = 1000
    seq_sine_period = 30
    y = (base / 10 + price_rand * 20) * torch.sin (torch.arange (seq_len, device = device) / seq_sine_period) + base
    period = 21
    y = y.view (1, 1, -1)
    # warmup
    for i in range (nWarm):
        _ = sync_EMA (y, period)
        _ = sync_EMA_2 (y, period)
        _ = sync_EMA_3 (y, period)
        _ = sync_EMA_3b (y, period)
    # timings
    print ('timings, seq_len:', seq_len)
    print ('device:', device)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime1 = time.time()
    for i in range (nTime):
        EMA1 = sync_EMA (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA: ', (time.time() - starttime1) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime2 = time.time()
    for i in range (nTime):
        EMA2 = sync_EMA_2 (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_2: ', (time.time() - starttime2) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime3 = time.time()
    for i in range (nTime):
        EMA3 = sync_EMA_3 (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_3: ', (time.time() - starttime3) / nTime)
    if device == 'cuda':
        torch.cuda.synchronize()
    starttime3b = time.time()
    for i in range (nTime):
        EMA3b = sync_EMA_3b (y, period)
    if device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_3b:', (time.time() - starttime3b) / nTime)
    print ('check EMA1: ', torch.allclose (EMA1, EMA2))
    print ('check EMA3: ', torch.allclose (EMA3, EMA2))
    print ('check EMA3b:', torch.allclose (EMA3b, EMA2))
    print ('check EMA3b, EMA3: ', torch.equal (EMA3b, EMA3))
    print ('EMA1, EMA2 max diff: ', (EMA1 - EMA2).abs().max())
```

Here are the cpu timings:

```
1.12.0
timings, seq_len: 1000
device: cpu
sync_EMA: 0.0014473199844360352
sync_EMA_2: 0.05421710014343262
sync_EMA_3: 0.0019309282302856444
sync_EMA_3b: 0.00019500255584716796
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3462)
timings, seq_len: 2000
device: cpu
sync_EMA: 0.0019614458084106444
sync_EMA_2: 0.10763685703277588
sync_EMA_3: 0.005632424354553222
sync_EMA_3b: 0.001221442222595215
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.2375)
timings, seq_len: 5000
device: cpu
sync_EMA: 0.009285163879394532
sync_EMA_2: 0.26809670925140383
sync_EMA_3: 0.039087104797363284
sync_EMA_3b: 0.006497597694396973
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.2563)
timings, seq_len: 10000
device: cpu
sync_EMA: 0.008062648773193359
sync_EMA_2: 0.5337584495544434
sync_EMA_3: 0.14901502132415773
sync_EMA_3b: 0.024390339851379395
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.8373)
```

And the gpu timings:

```
1.12.0
timings, seq_len: 1000
device: cuda
sync_EMA: 0.001915717124938965
sync_EMA_2: 0.10515487194061279
sync_EMA_3: 0.00032756328582763674
sync_EMA_3b: 0.00014753341674804686
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.0728, device='cuda:0')
timings, seq_len: 2000
device: cuda
sync_EMA: 0.0024464607238769533
sync_EMA_2: 0.20999035835266114
sync_EMA_3: 0.001055598258972168
sync_EMA_3b: 0.00021555423736572266
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3566, device='cuda:0')
timings, seq_len: 5000
device: cuda
sync_EMA: 0.004022955894470215
sync_EMA_2: 0.5285399198532105
sync_EMA_3: 0.005906414985656738
sync_EMA_3b: 0.0010313272476196289
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.3903, device='cuda:0')
timings, seq_len: 10000
device: cuda
sync_EMA: 0.006554818153381348
sync_EMA_2: 1.056663703918457
sync_EMA_3: 0.023943758010864256
sync_EMA_3b: 0.00549619197845459
check EMA1: False
check EMA3: True
check EMA3b: True
check EMA3b, EMA3: True
EMA1, EMA2 max diff: tensor(11.7476, device='cuda:0')
```

Again, `sync_EMA` (`EMA1`) computes something different than your original version (`sync_EMA_2` / `EMA2`). Whether that matters would depend on the details of your use case.

The precomputed-weight-matrix version (`sync_EMA_3b` / `EMA3b`) both agrees with your computation and outperforms the other versions for rather large values of `seq_len`, in spite of its `seq_len**2` inefficiency.
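If it helps to see why the weight-matrix approach reproduces your recurrence, here is a minimal numeric check (my own toy example, not part of the timings): unrolling `y[i] = w*x[i] + (1-w)*y[i-1]` with `y[0] = x[0]` gives a closed-form weighted sum of past values, which is exactly what the triangular matrix of powers computes.

```
import torch

w = 2 / (21 + 1)                            # EMA weight for period = 21
x = torch.randn (8, dtype = torch.float64)  # toy sequence (float64 just for the check)

# loop version of the recurrence: y[0] = x[0], y[i] = w*x[i] + (1-w)*y[i-1]
y_loop = torch.empty_like (x)
y_loop[0] = x[0]
for i in range (1, len (x)):
    y_loop[i] = w * x[i] + (1 - w) * y_loop[i - 1]

# closed form: y[i] = w * sum_{k <= i} (1-w)**(i-k) * x[k] + (1-w)**(i+1) * x[0]
n = len (x)
p = (1 - w) ** torch.arange (n + 1, dtype = torch.float64)
v = p.repeat (n).reshape (n + 1, n)[:-1].triu()   # v[k, i] = (1-w)**(i-k) for i >= k
y_mat = w * x @ v + (1 - w) * x[0] * v[0]
print (torch.allclose (y_loop, y_mat))            # True
```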

This just underscores the general rule that python loops kill your performance relative to (python) loop-free tensor operations, especially on a gpu.

(As an aside, if performance remains an issue for you, the `seq_len**2` cost can be eliminated with rather more programming effort.)
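For what it's worth, one way to do this is sketched below (my own illustration; `ema_chunked` and its `chunk` parameter are names I made up, and I haven't timed it): apply the same weight-matrix trick within fixed-size chunks and carry the running EMA value across chunk boundaries. This replaces the `seq_len**2` matrix with roughly `seq_len * chunk` work, at the price of a short python loop over chunks.

```
import torch

def ema_chunked (x, period: int, chunk: int = 256):
    # EMA along the last dim of x (shape (batch, channels, seq_len)),
    # seeded with y[0] = x[0] to match the original loop version
    w = 2 / (period + 1)
    n = x.shape[-1]
    c = min (chunk, n)
    p = (1 - w) ** torch.arange (c + 1, dtype = x.dtype, device = x.device)
    v = p.repeat (c).reshape (c + 1, c)[:-1].triu()   # v[k, i] = (1-w)**(i-k), i >= k
    carry = x[..., 0, None]     # y[-1] = x[0] reproduces y[0] = x[0] in the first chunk
    out = []
    for start in range (0, n, c):
        xc = x[..., start:start + c]
        m = xc.shape[-1]        # last chunk may be shorter
        # within a chunk: y[i] = w * sum_{k <= i} (1-w)**(i-k) * xc[k] + (1-w)**(i+1) * carry
        yc = w * xc @ v[:m, :m] + carry * p[1:m + 1]
        out.append (yc)
        carry = yc[..., -1, None]   # EMA state entering the next chunk
    return torch.cat (out, dim = -1)
```

The chunk size is a tunable trade-off: larger chunks mean bigger matrices but fewer python-loop iterations.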

Best.

K. Frank