For loop slows training down

I’m trying to do an exponential moving average. Shape of x is (batch size, channels, length)

y = torch.empty_like(x)
y[:, :, 0] = x[:, :, 0]
for t in range(1, y.shape[-1]):
    y[:, :, t] = 0.99 * y[:, :, t-1] + 0.01 * x[:, :, t]

But this code seems to run very slowly. Any suggestions for optimizing it? Thanks.

Turning EMA into an asynchronous operation is not exactly trivial. However, it is possible.

You can try this out:

import torch

def sync_EMA(x, period: int):  # x tensor size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(end=period)

    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)

    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p)), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)
    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)

And a short test:

x = torch.tensor([[[22, 34, 22, 18, 23, 34, 45, 40, 47, 46]]])
period = 3
print(sync_EMA(x, period))

tensor([[[22.0000, 30.0000, 26.0000, 22.0000, 21.1667, 27.6667, 37.6667,
          40.6667, 44.3333, 45.3333]]])

A couple of caveats.

  1. This assumes your sequence runs from oldest to newest. If not, flip it along dim=2 before calling the function and flip the result back afterwards.
  2. Your function did not address the oldest values meaningfully, likely resulting in you throwing them out later. The above function treats the oldest values as though the period starts from 1 and continues up to the “period” parameter set.
  3. This function assumes the same dim shape as in your stated function (although it can easily be adapted to others).
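
To make caveat 2 concrete, here is a minimal plain-Python sketch (no PyTorch, so it illustrates the weighting rather than replacing the function above). The helper names `linear_weights` and `windowed_average` are made up for this example. With an effective window n = min(t + 1, period), the sample k steps back gets weight 2 * (n - k) / (n * (n + 1)), so the weights fall off linearly with age and sum to 1.

```python
def linear_weights(n):
    """Weights from the newest sample (k = 0) back to the oldest (k = n - 1)."""
    return [2 * (n - k) / (n * (n + 1)) for k in range(n)]

def windowed_average(x, period):
    out = []
    for t in range(len(x)):
        n = min(t + 1, period)      # the window grows from 1 up to `period`
        w = linear_weights(n)
        out.append(sum(w[k] * x[t - k] for k in range(n)))
    return out

x = [22, 34, 22, 18, 23, 34, 45, 40, 47, 46]
y = windowed_average(x, period=3)
print([round(v, 4) for v in y])
```

For period = 3 the steady-state weights are 1/2, 1/3 and 1/6, and the printed values match the tensor shown in the short test above.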

If this solves the issue for you, please select this as the answer so others can find it more easily.

Thank you for the reply. I might have a slightly different understanding of EMA. What I want is a simple IIR filter like:
y[n] = c * y[n-1] + (1-c) * x[n]
Therefore, y[n] should depend on all values of x[0:n]. Please let me know if I missed something.
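
Unrolling that recurrence makes the dependence explicit. Assuming the initialization y[0] = x[0] used by the loop in the original post:

```latex
y[n] = c\,y[n-1] + (1-c)\,x[n]
     = c^{n}\,x[0] + (1-c)\sum_{k=1}^{n} c^{\,n-k}\,x[k]
```

So the sample x[n-m] enters with weight (1-c) c^m, which decays geometrically with m but never reaches exactly zero.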

EMA can be handled non-iteratively if you don’t require all of the values outside of the period range provided. This is an approximate method.

I am using the definition from here, but slightly modified in order to normalize the initial values in the sequence:

But what it does is drops off the values outside of the period, making it more responsive to recent data. Here are some examples, compared with vanilla iterative EMAs.

[Images: three plots comparing this method with an iterative EMA]

Those are with periods of 21, 10, and 5 respectively.

And here is an example with a rising and falling sequence, period of 5:

[Image: rising and falling sequence, period of 5]

Anyway, to make the calculation asynchronous and faster, you’ll need an approximate method. An iterative version will always be much slower because it bottlenecks the calculation on one processor.

Hi Tiance!

At the cost of materializing a length x length matrix (which could be costly
if length is large) you can reproduce your exponential-moving-average
computation without a for-loop as follows:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> x = torch.randn (2, 3, 10)   # length = 10
>>> w = 0.01
>>>
>>> # for loop version
>>> y = torch.empty_like(x)
>>> y[:, :, 0] = x[:, :, 0]
>>> for t in range(1, y.shape[-1]):
...     y[:, :, t] = (1-w) * y[:, :, t-1] + w * x[:, :, t]
...
>>> # loop-free version
>>> n = x.shape[-1]
>>> p = (1 - w)**torch.arange (n + 1)                 # powers of (1 - weight)
>>> v = p.repeat (n).reshape (n + 1, n)[:-1].triu()   # length x length matrix of powers
>>>
>>> yB = w * x @ v + (1 - w) * x[..., 0, None] @ v[None, 0]
>>>
>>> torch.allclose (y, yB)
True

Edit: Note that for large values of length and small values of the quantity
being exponentiated (in the above code, 1 - w), the iterative for-loop
version is likely to be more numerically stable.
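
The line that builds v is compact, so here is a small plain-Python sketch of why it works (lists instead of tensors; the variable names are chosen for this illustration only). Tiling the n + 1 powers n times and cutting the result into rows of length n shifts each row one step relative to the previous one, so the entry in row r, column c is p[(c - r) mod (n + 1)]. On and above the diagonal that is (1 - w)**(c - r), and triu() zeroes everything below it.

```python
n = 5
w = 0.01
p = [(1 - w) ** k for k in range(n + 1)]     # n + 1 powers of (1 - w)

flat = p * n                                 # plays the role of p.repeat(n)
# plays the role of .reshape(n + 1, n)[:-1]
rows = [flat[r * n:(r + 1) * n] for r in range(n + 1)][:-1]

# keep the upper triangle, like .triu()
v = [[rows[r][c] if c >= r else 0.0 for c in range(n)] for r in range(n)]

# every kept entry is the power (1 - w)**(c - r)
for r in range(n):
    for c in range(r, n):
        assert v[r][c] == p[c - r]
print(v[0])
```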

Best.

K. Frank

Thanks for the explanation!

Thanks, I’ll try and see how much speedup I get. Neat trick to get the triangular weight matrix!

Hello @KFrank ! Always like seeing your answers because the solutions are usually elegant and show just what PyTorch is capable of.

In this case, however, I think that approach might actually take longer than an iterative loop, as the size of v grows quadratically with the length of the sequence.

I just ran a few time tests with all three approaches, but with a longer sequence, to see which does better:

import torch

def sync_EMA(x, period: int):  # x size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(start=0, end=period)

    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)

    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p)), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)

    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)


def sync_EMA_2(x, period: int):  # x size of (batch, channels, sequence)
    y = torch.zeros_like(x)
    mult = 2 / (period + 1)
    for i in range(x.size()[-1]):
        if i == 0:
            y[:, :, i] = x[:, :, i]
        else:
            y[:, :, i] = x[:, :, i] * mult + y[:, :, i - 1] * (1 - mult)

    return y

def sync_EMA_3(x, period: int):  # x size of (batch, channels, sequence)
    n = x.shape[-1]
    w = 2 / (period + 1)  # same multiplier as sync_EMA_2
    p = (1 - w) ** torch.arange(n + 1)  # powers of (1 - weight)
    v = p.repeat(n).reshape(n + 1, n)[:-1].triu()  # length x length matrix of powers
    yB = w * x @ v + (1 - w) * x[..., 0, None] @ v[None, 0]
    return yB

seq_len=50000
price_rand=torch.randn(seq_len)
base = 1000
seq_sine_period=30

y=(base/10+price_rand*20)*torch.sin(torch.arange(seq_len)/seq_sine_period)+base
period = 21

y=y.view(1,1,-1)
import time
starttime1=time.time()
EMA1=sync_EMA(y, period)
print(time.time()-starttime1)
starttime2=time.time()
EMA2=sync_EMA_2(y, period)
print(time.time()-starttime2)
starttime3=time.time()
EMA3=sync_EMA_3(y, period)
print(time.time()-starttime3)

Those processed in:

0.09100008010864258
2.7660350799560547
4.646963834762573

Additionally, if the sequence is longer, it can run into memory issues. I’m not sure what @wangtiance’s application is, but in my own work, sequences in excess of 300k length are fairly common.

But I really like the answer for its ingenuity! Bookmarking it, by the way.

Hi J!

Two comments on performance:

In use cases where the “weight” (maps to your period) doesn’t change
one would typically precompute the weight matrix (my v). Then the (core
of the) moving-average computation becomes a tensor-matrix multiplication,
something that pytorch is very good at.

Note, even if the length (your seq_len) changes, you can still precompute.
For smaller lengths you can just slice into the larger-length weight matrix.
(Tensor-matrix-slice multiplication is also something pytorch is very good at.)
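
A minimal sketch of the slicing remark, in plain Python rather than PyTorch (the helper `power_matrix` is hypothetical): because the entry at row r, column c depends only on c - r, the top-left m x m block of the matrix built for a long sequence is exactly the matrix for length m. With tensors the slice would be v[:m, :m].

```python
def power_matrix(n, w):
    """Upper-triangular matrix with (1 - w)**(c - r) at row r, column c."""
    return [[(1 - w) ** (c - r) if c >= r else 0.0 for c in range(n)]
            for r in range(n)]

w = 0.01
big = power_matrix(8, w)                  # precomputed once for the longest length
m = 3
small = [row[:m] for row in big[:m]]      # slice for a shorter sequence

assert small == power_matrix(m, w)
print(small)
```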

If you are backpropagating, with respect to either x or period (or both),
I could well imagine the for-loop version imposing a big performance hit.

If you are backpropagating with respect to period, you are presumably
training period, so you wouldn’t be able to precompute the weight matrix
(or you would at least have to recompute it after each optimizer step). I
haven’t done any timing tests, but I could well imagine that backpropagating
through computing the weight matrix would nonetheless be cheaper than
backpropagating through the for-loop.

Yes, if seq_len is large, v could become unmanageably large.
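
For a rough sense of scale, here is a back-of-the-envelope estimate. It assumes float32 (4 bytes per element) and counts only v itself; intermediate tensors created while building it come on top. The function name is made up for this sketch.

```python
def weight_matrix_bytes(seq_len, bytes_per_element=4):
    """Storage for a dense seq_len x seq_len matrix."""
    return seq_len * seq_len * bytes_per_element

for seq_len in (5_000, 10_000, 50_000, 300_000):
    gb = weight_matrix_bytes(seq_len) / 1e9
    print(f"seq_len = {seq_len:>7}: {gb:8.1f} GB")
```

That works out to 0.4 GB at a length of 10,000, 10 GB at 50,000, and 360 GB at 300,000.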

Best.

K. Frank

My input data is extracted from audio features and is typically just a few thousand points long, so I guess memory is not a big issue. Anyway, approximating the IIR filter with a 20th- or 30th-order FIR filter might just work well.
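
A quick check of how many taps such an FIR filter needs, assuming it is built by truncating the impulse response (1 - c) * c**k after N taps. The discarded taps sum to c**N, so that is the fraction of the total weight that gets lost. The helper names are made up for this sketch.

```python
import math

def discarded_weight(c, n_taps):
    """Total weight of the taps dropped by keeping only the first n_taps."""
    return c ** n_taps

def taps_needed(c, tail):
    """Smallest N with c**N <= tail."""
    return math.ceil(math.log(tail) / math.log(c))

c = 0.99                            # coefficient from the original post
print(discarded_weight(c, 30))      # roughly 0.74
print(taps_needed(c, 0.01))         # 459
```

With c = 0.99, a 30-tap filter keeps only about a quarter of the weight, so a 20th- or 30th-order FIR is a close approximation only for a noticeably smaller c (for example, 0.8**30 is about 0.001).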

Hi Tiance and J!

My workstation gpu (and cpu) can accommodate a seq_len of a few
thousand running the loop-free version.

A few more comments about performance: The sync_EMA() version
is really not an apples-to-apples comparison as it doesn’t reproduce
the computation you originally posted. (If you don’t need your original
exponential-moving-average formulation, there are lots of ways you
might speed things up.)

I’ve run some more systematic timings. The precomputed-weight-matrix
version outperforms sync_EMA() on my gpu for seq_len up to 10,000
(about the largest that fits on the gpu) and for seq_len up to 5,000 on
the same machine’s cpu.

Here is an expanded version of J’s timing code:

import torch
print (torch.__version__)

import time

_ = torch.manual_seed (2022)

device = 'cpu'
if  torch.cuda.is_available():
    device = 'cuda'

def sync_EMA(x, period: int):  # x size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1, device = device)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(start=0, end=period, device = device)
    
    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)
    
    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p,), device = device), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)
    
    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)

def sync_EMA_2(x, period: int):  # x size of (batch, channels, sequence)
    y = torch.zeros_like(x)
    mult = 2 / (period + 1)
    for i in range(x.size()[-1]):
        if i == 0:
            y[:, :, i] = x[:, :, i]
        else:
            y[:, :, i] = x[:, :, i] * mult + y[:, :, i - 1] * (1 - mult)
    
    return y

def sync_EMA_3(x, period: int):  # x size of (batch, channels, sequence)
    n = x.shape[-1]
    # w=1/period
    w = 2 / (period + 1)
    p = (1 - w) ** torch.arange(n + 1, device = device)  # powers of (1 - weight)
    v = p.repeat(n).reshape(n + 1, n)[:-1].triu()  # length x length matrix of powers
    yB = w * x @ v + (1 - w) * x[..., 0, None] @ v[None, 0]
    return yB

periodPrecompute = None
nPrecompute = None
vPrecompute = None

def sync_EMA_3b (x, period: int):
    global periodPrecompute, nPrecompute, vPrecompute
    n = x.shape[-1]
    w = 2 / (period + 1)
    if  vPrecompute is None  or  period != periodPrecompute  or  n != nPrecompute:
        p = (1 - w) ** torch.arange(n + 1, device = device)
        v = p.repeat(n).reshape(n + 1, n)[:-1].triu()
        vPrecompute = v
        nPrecompute = n
        periodPrecompute = period
    yB = w * x @ vPrecompute + (1 - w) * x[..., 0, None] @ vPrecompute[None, 0]
    return yB

nWarm = 3
nTime = 10

for  seq_len in (1000, 2000, 5000, 10000):
    price_rand = torch.randn (seq_len, device = device)
    base = 1000
    seq_sine_period=30
    
    y = (base/10+price_rand*20)*torch.sin (torch.arange (seq_len, device = device)/seq_sine_period)+base
    period = 21
    
    y=y.view(1,1,-1)

    # warmup
    for  i in range (nWarm):
        _ = sync_EMA (y, period)
        _ = sync_EMA_2 (y, period)
        _ = sync_EMA_3 (y, period)
        _ = sync_EMA_3b (y, period)
    
    # timings
    
    print ('timings, seq_len:', seq_len)
    print ('device:', device)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    starttime1 = time.time()
    for  i in range (nTime):
        EMA1 = sync_EMA (y, period)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    print ('sync_EMA:   ', (time.time() - starttime1) / nTime)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    starttime2 = time.time()
    for  i in range (nTime):
        EMA2 = sync_EMA_2 (y, period)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    print ('sync_EMA_2: ', (time.time() - starttime2) / nTime)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    starttime3 = time.time()
    for  i in range (nTime):
        EMA3 = sync_EMA_3 (y, period)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    print ('sync_EMA_3: ', (time.time() - starttime3) / nTime)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    starttime3b = time.time()
    for  i in range (nTime):
        EMA3b = sync_EMA_3b (y, period)
    
    if  device == 'cuda':
        torch.cuda.synchronize()
    
    print ('sync_EMA_3b:', (time.time() - starttime3b) / nTime)
    
    print ('check EMA1: ', torch.allclose (EMA1, EMA2))
    print ('check EMA3: ', torch.allclose (EMA3, EMA2))
    print ('check EMA3b:', torch.allclose (EMA3b, EMA2))
    
    print ('check EMA3b, EMA3: ', torch.equal (EMA3b, EMA3))
    
    print ('EMA1, EMA2 max diff: ', (EMA1 - EMA2).abs().max())

Here are the cpu timings:

1.12.0
timings, seq_len: 1000
device: cpu
sync_EMA:    0.0014473199844360352
sync_EMA_2:  0.05421710014343262
sync_EMA_3:  0.0019309282302856444
sync_EMA_3b: 0.00019500255584716796
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.3462)
timings, seq_len: 2000
device: cpu
sync_EMA:    0.0019614458084106444
sync_EMA_2:  0.10763685703277588
sync_EMA_3:  0.005632424354553222
sync_EMA_3b: 0.001221442222595215
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.2375)
timings, seq_len: 5000
device: cpu
sync_EMA:    0.009285163879394532
sync_EMA_2:  0.26809670925140383
sync_EMA_3:  0.039087104797363284
sync_EMA_3b: 0.006497597694396973
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.2563)
timings, seq_len: 10000
device: cpu
sync_EMA:    0.008062648773193359
sync_EMA_2:  0.5337584495544434
sync_EMA_3:  0.14901502132415773
sync_EMA_3b: 0.024390339851379395
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.8373)

And the gpu timings:

1.12.0
timings, seq_len: 1000
device: cuda
sync_EMA:    0.001915717124938965
sync_EMA_2:  0.10515487194061279
sync_EMA_3:  0.00032756328582763674
sync_EMA_3b: 0.00014753341674804686
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.0728, device='cuda:0')
timings, seq_len: 2000
device: cuda
sync_EMA:    0.0024464607238769533
sync_EMA_2:  0.20999035835266114
sync_EMA_3:  0.001055598258972168
sync_EMA_3b: 0.00021555423736572266
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.3566, device='cuda:0')
timings, seq_len: 5000
device: cuda
sync_EMA:    0.004022955894470215
sync_EMA_2:  0.5285399198532105
sync_EMA_3:  0.005906414985656738
sync_EMA_3b: 0.0010313272476196289
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.3903, device='cuda:0')
timings, seq_len: 10000
device: cuda
sync_EMA:    0.006554818153381348
sync_EMA_2:  1.056663703918457
sync_EMA_3:  0.023943758010864256
sync_EMA_3b: 0.00549619197845459
check EMA1:  False
check EMA3:  True
check EMA3b: True
check EMA3b, EMA3:  True
EMA1, EMA2 max diff:  tensor(11.7476, device='cuda:0')

Again, sync_EMA (EMA1) computes something different than your
original version (sync_EMA_2 / EMA2). Whether that matters would
depend on the details of your use case.

The precomputed-weight-matrix version (sync_EMA_3b / EMA3b) both
agrees with your computation and outperforms the other versions for
rather large values of seq_len, in spite of its seq_len**2 inefficiency.

This just underscores the general rule that python loops kill you in
performance relative to (python) loop-free tensor operations, especially
on a gpu.

(As an aside, if performance remains an issue for you, the seq_len**2
cost can be eliminated with rather more programming effort.)
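
One possible way to do that (only one option, and not necessarily what is meant above) is to process the sequence in blocks: inside a block, use the same closed-form weights, and carry the last output of the previous block forward as the starting state. The plain-Python sketch below shows the bookkeeping. `ema_loop` and `ema_blocked` are illustrative names; with tensors, each block would become one (block x block) matrix product, so memory scales with block**2 instead of seq_len**2.

```python
def ema_loop(x, alpha):
    """Reference: the plain recurrence with y[0] = x[0]."""
    y = [x[0]]
    for t in range(1, len(x)):
        y.append(alpha * x[t] + (1 - alpha) * y[-1])
    return y

def ema_blocked(x, alpha, block):
    decay = 1 - alpha
    powers = [decay ** k for k in range(block + 1)]
    y = []
    carry = x[0]    # stands in for "y[-1]"; this choice yields y[0] = x[0]
    for start in range(0, len(x), block):
        chunk = x[start:start + block]
        for j in range(len(chunk)):
            # contribution of the samples inside the current block
            inside = sum(alpha * powers[j - i] * chunk[i] for i in range(j + 1))
            # decayed contribution of everything before the block
            y.append(powers[j + 1] * carry + inside)
        carry = y[-1]
    return y

x = [float((7 * t) % 13) for t in range(50)]
a = ema_loop(x, 0.1)
b = ema_blocked(x, 0.1, block=8)
print(max(abs(p - q) for p, q in zip(a, b)))    # difference is rounding noise
```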

Best.

K. Frank

Again, sync_EMA (EMA1) computes something different than your
original version (sync_EMA_2 / EMA2). Whether that matters would
depend on the details of your use case.

At an EMA period of 199 (which works out to a multiplier of 0.01, as in the original post), the difference between the two EMA definitions averages approximately 0.12%, or about one part in a thousand. Additionally, storing sync_EMA’s w matrix, as you demonstrated with v and n, can also provide a speed-up on reuse. The difference is that the w matrix has size seq_len x period, so it grows only linearly with seq_len or period. Not a bad tradeoff when working with longer sequences.

Hello @J_Johnson and @KFrank , I ended up using the precomputed matrix version and the speed bottleneck is gone.

The amount of help I’m getting from you is incredible. Thank you.

Hi Tiance and J!

A brief follow-up on our discussion:

The bulk of the exponential-moving-average computation can be packaged
as a convolution.

This suggests the following scheme: Eliminate the length**2 cost in
space by using a convolution and eliminate the length**2 cost in time
by truncating the exponentially-decaying convolution kernel comfortably
below some conservative estimate of round-off error.

Here is a script with such an implementation and some timing code:

import torch
print (torch.__version__)

import math

import time

_ = torch.manual_seed (2022)

devices = ['cpu']
if  torch.cuda.is_available():
    devices.append ('cuda')

# exponential moving average using conv1d with precomputed 1d kernel
class EMA (torch.nn.Module):
    def __init__ (self, alpha = 0.10, eps = 1.e-8, maxSeq = None):
        super().__init__()
        if  alpha <= 0.0  or  alpha >= 1.0:
            raise ValueError ('bad alpha')
        self.alpha = alpha     # exponential-moving-average "weight"
        self.eps = eps         # conservative (1 - alpha)**n cutoff
        self.maxSeq = maxSeq   # maximum sequence length that will be passed in
        
        self.albar = 1.0 - self.alpha
        self.nKernel = int (math.log (eps) / math.log (self.albar))   # truncated kernel length
        
        if  not self.maxSeq is None  and  self.nKernel > self.maxSeq:
            self.nKernel = self.maxSeq   # truncate kernel with maxSeq rather than eps
        else:
            self.maxSeq = None           # full (truncated) kernel regardless of sequence length
        
        kern = self.albar**(torch.arange (self.nKernel, 0, -1) - 1).unsqueeze (0).unsqueeze (0)
        self.register_buffer ('kFlip', self.albar * kern.flip (-1))   # "flipped" kernel for x[0] term of ema
        kern *= self.alpha
        self.register_buffer ('kernel', kern)
    
    def forward (self, x):
        if  x.size (-1) < self.nKernel:
            kern = self.kernel[:, :, -x.size (-1):]
            kflp = self.kFlip[:, :, :x.size (-1)] 
        else:
            kern = self.kernel
            kflp = torch.zeros (1, 1, x.size (-1), device = self.kFlip.device)
            kflp[0, 0, :self.nKernel] = self.kFlip
        
        y = x.unsqueeze (-2).unsqueeze (-2)   # add dummy channel and possible batch dimensions
        y = y.flatten (0, -2).unsqueeze (1)   # merge leading dimensions and add channel dimension
        z = torch.nn.functional.conv1d (y, kern, padding = kern.size (-1) - 1)[:, :, :-kern.size (-1) + 1] + kflp * y[:, :, 0:1]
        z = z.view (x.shape)   # restore dimensions
        return z

# for-loop version of exponential moving average
def loopEMA (x, alpha):  # x size of (batch, channels, sequence)
    y = torch.zeros_like (x)
    for i in range(x.size()[-1]):
        if i == 0:
            y[:, :, i] = x[:, :, i]
        else:
            y[:, :, i] = x[:, :, i] * alpha + y[:, :, i - 1] * (1 - alpha)
    
    return y

# "sync" version of exponential moving average
def sync_EMA(x, period: int):  # x size of (batch, channels, sequence)
    xy = torch.arange(start=1, end=x.size()[-1] + 1, device = x.device)
    xy_mask = xy > period
    xy[xy_mask] = period
    xz = torch.arange(start=0, end=period, device = x.device)
    
    xy = torch.cat([xy.unsqueeze(0)] * xz.size()[0])
    xz = torch.stack([xz] * xy.size()[1], dim=1)
    
    w = 2 * (xy - xz) / (xy * (xy + 1))
    w = w.T
    w = w.unsqueeze(0).unsqueeze(0)
    xx = torch.stack(
        [torch.cat([torch.zeros((x.size()[0], x.size()[1], p,), device = x.device), x[..., :-p]], dim=-1) for p in range(1, period)],
        dim=-1)
    
    xx = torch.cat([x.unsqueeze(-1), xx], dim=-1)
    xx = xx * w
    return torch.sum(xx, dim=-1)

# precomputed weight-matrix version of exponential moving average
# global variables for precomputation
alphaPrecompute = None
nPrecompute = None
devPrecompute = None
vPrecompute = None

def weightEMA (x, alpha):
    global alphaPrecompute, nPrecompute, devPrecompute, vPrecompute
    n = x.shape[-1]
    dev = x.device
    if  vPrecompute is None  or  alpha != alphaPrecompute  or  n != nPrecompute  or  dev != devPrecompute:
        p = (1 - alpha) ** torch.arange(n + 1, device = dev)
        v = p.repeat(n).reshape(n + 1, n)[:-1].triu()
        vPrecompute = v
        nPrecompute = n
        alphaPrecompute = alpha
        devPrecompute = dev   # remember the device too, or the cache check above never hits
    y = alpha * x @ vPrecompute + (1 - alpha) * x[..., 0, None] @ vPrecompute[None, 0]
    return y

# run timings

period = 21
alpha = 2 / (period + 1)

ema = EMA (alpha)   # instantiate convolutional EMA function object

nWarm = 3
nTime = 10

for  dev in devices:
    ema = ema.to (dev)
    for  seq_len in (100, 300, 1000, 3000, 10000, 30000, 100000, 300000):
        
        # create test data
        price_rand = torch.randn (seq_len, device = dev)
        base = 1000
        seq_sine_period = 30
        x = (base / 10 + price_rand * 20) * torch.sin (torch.arange (seq_len, device = dev) / seq_sine_period) + base
        x = x.view (1, 1, -1)
    
        # warmup
        for  i in range (nWarm):
            _ = loopEMA (x, alpha)
            _ = sync_EMA (x, period)
            if  (dev == 'cpu' and seq_len <= 30000)  or  (dev == 'cuda' and seq_len <= 10000):
                _ = weightEMA (x, alpha)
            _ = ema (x)
    
        # timings
        
        print ('seq_len: %7d  (%s)' % (seq_len, dev))
        
        if  dev == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for  i in range (nTime):
            loopEma = loopEMA (x, alpha)
        if  dev == 'cuda':
            torch.cuda.synchronize()
        t = (time.time() - start) / nTime
        print ('   loopEMA:   %8.5f' % (t,))
        
        if  dev == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for  i in range (nTime):
            syncEma = sync_EMA (x, period)
        if  dev == 'cuda':
            torch.cuda.synchronize()
        t = (time.time() - start) / nTime
        d = (syncEma - loopEma).abs().max().item()
        print ('   sync_EMA:  %8.5f  (max-diff: %8.5f)' % (t, d))
        
        if  (dev == 'cpu' and seq_len <= 30000)  or  (dev == 'cuda' and seq_len <= 10000):
            if  dev == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            for  i in range (nTime):
                weightEma = weightEMA (x, alpha)
            if  dev == 'cuda':
                torch.cuda.synchronize()
            t = (time.time() - start) / nTime
            d = (weightEma - loopEma).abs().max().item()
            print ('   weightEMA: %8.5f  (max-diff: %8.5f)' % (t, d))
        
        if  dev == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for  i in range (nTime):
            convEma = ema (x)
        if  dev == 'cuda':
            torch.cuda.synchronize()
        t = (time.time() - start) / nTime
        d = (convEma - loopEma).abs().max().item()
        print ('   conv-EMA:  %8.5f  (max-diff: %8.5f)' % (t, d))

And here is its output:

1.12.0
seq_len:     100  (cpu)
   loopEMA:    0.00561
   sync_EMA:   0.00066  (max-diff:  9.06934)
   weightEMA:  0.00017  (max-diff:  0.00024)
   conv-EMA:   0.00011  (max-diff:  0.00024)
seq_len:     300  (cpu)
   loopEMA:    0.01680
   sync_EMA:   0.00085  (max-diff: 11.38959)
   weightEMA:  0.00024  (max-diff:  0.00024)
   conv-EMA:   0.00014  (max-diff:  0.00024)
seq_len:    1000  (cpu)
   loopEMA:    0.05427
   sync_EMA:   0.00138  (max-diff: 11.06018)
   weightEMA:  0.00144  (max-diff:  0.00024)
   conv-EMA:   0.00016  (max-diff:  0.00037)
seq_len:    3000  (cpu)
   loopEMA:    0.16156
   sync_EMA:   0.00826  (max-diff: 11.93427)
   weightEMA:  0.01656  (max-diff:  0.00037)
   conv-EMA:   0.00029  (max-diff:  0.00037)
seq_len:   10000  (cpu)
   loopEMA:    0.53862
   sync_EMA:   0.00840  (max-diff: 11.80554)
   weightEMA:  0.14688  (max-diff:  0.00049)
   conv-EMA:   0.00091  (max-diff:  0.00037)
seq_len:   30000  (cpu)
   loopEMA:    1.61875
   sync_EMA:   0.02766  (max-diff: 11.90173)
   weightEMA:  1.30305  (max-diff:  0.00037)
   conv-EMA:   0.01089  (max-diff:  0.00037)
seq_len:  100000  (cpu)
   loopEMA:    5.39038
   sync_EMA:   0.10049  (max-diff: 11.94653)
   conv-EMA:   0.03603  (max-diff:  0.00049)
seq_len:  300000  (cpu)
   loopEMA:   16.20274
   sync_EMA:   0.31397  (max-diff: 12.11890)
   conv-EMA:   0.10671  (max-diff:  0.00049)
seq_len:     100  (cuda)
   loopEMA:    0.01021
   sync_EMA:   0.00137  (max-diff: 10.35114)
   weightEMA:  0.00028  (max-diff:  0.00024)
   conv-EMA:   0.00018  (max-diff:  0.00024)
seq_len:     300  (cuda)
   loopEMA:    0.03067
   sync_EMA:   0.00150  (max-diff: 10.89624)
   weightEMA:  0.00029  (max-diff:  0.00024)
   conv-EMA:   0.00020  (max-diff:  0.00024)
seq_len:    1000  (cuda)
   loopEMA:    0.10285
   sync_EMA:   0.00187  (max-diff: 11.38379)
   weightEMA:  0.00033  (max-diff:  0.00037)
   conv-EMA:   0.00021  (max-diff:  0.00049)
seq_len:    3000  (cuda)
   loopEMA:    0.30884
   sync_EMA:   0.00311  (max-diff: 11.71997)
   weightEMA:  0.00218  (max-diff:  0.00037)
   conv-EMA:   0.00022  (max-diff:  0.00037)
seq_len:   10000  (cuda)
   loopEMA:    1.03287
   sync_EMA:   0.00733  (max-diff: 11.49841)
   weightEMA:  0.02417  (max-diff:  0.00037)
   conv-EMA:   0.00021  (max-diff:  0.00037)
seq_len:   30000  (cuda)
   loopEMA:    3.08627
   sync_EMA:   0.01811  (max-diff: 11.88159)
   conv-EMA:   0.00024  (max-diff:  0.00049)
seq_len:  100000  (cuda)
   loopEMA:   10.27835
   sync_EMA:   0.05575  (max-diff: 12.02441)
   conv-EMA:   0.00054  (max-diff:  0.00049)
seq_len:  300000  (cuda)
   loopEMA:   30.98715
   sync_EMA:   0.18238  (max-diff: 11.97083)
   conv-EMA:   0.00150  (max-diff:  0.00049)

(Timings for weightEMA have been skipped for the few largest sequence
lengths because the full weight matrix then exceeds the memory of the test
machine.)

In this version, the convolution kernel is precomputed when the EMA class
is instantiated. I haven’t performed timings, but I suspect that the bulk of the
performance increase will still be achieved if the kernel is computed on the
fly.
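
To see what the truncation amounts to, here is a plain-Python sketch (no conv1d, illustrative names only) that builds the taps on the fly. It uses the same cutoff rule as the EMA class above, int(log(eps) / log(1 - alpha)), and compares the truncated weighted sum against the full recurrence.

```python
import math

def kernel_length(alpha, eps=1e-8):
    """Number of taps given by the cutoff rule int(log(eps) / log(1 - alpha))."""
    return int(math.log(eps) / math.log(1 - alpha))

def ema_loop(x, alpha):
    y = [x[0]]
    for t in range(1, len(x)):
        y.append(alpha * x[t] + (1 - alpha) * y[-1])
    return y

def ema_truncated(x, alpha, n_kernel):
    decay = 1 - alpha
    taps = [alpha * decay ** k for k in range(n_kernel)]   # built on the fly
    y = []
    for t in range(len(x)):
        acc = sum(taps[k] * x[t - k] for k in range(min(t + 1, n_kernel)))
        if t < n_kernel:
            acc += decay ** (t + 1) * x[0]    # x[0] term, the role kFlip plays
        y.append(acc)
    return y

alpha = 2 / (21 + 1)                          # period = 21, as in the timings
n_kernel = kernel_length(alpha)
print(n_kernel)                               # 193

x = [1000 + 100 * math.sin(t / 30) for t in range(400)]
diff = max(abs(p - q) for p, q in zip(ema_loop(x, alpha),
                                      ema_truncated(x, alpha, n_kernel)))
print(diff)
```

With alpha = 2 / 22 and eps = 1e-8 the rule gives 193 taps regardless of the sequence length, which is why the cost of the convolutional version grows only linearly with seq_len.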

Best.

K. Frank

Nice work! That’s very interesting.

I agree, convolutions have a lot of untapped potential. I’m using them for SMAs, too. I was also thinking about how they might be useful for NLP. Dot-product attention is very inefficient for large language models!