Is there a way to make the following code (signal filter) run faster?

I want to implement a filter of the form a(t) = 0.995 * a(t-1) + (1 - 0.995) * in(t). I used TorchScript to speed it up, but it is still very slow. Is there a better way to implement it?

from typing import List

import torch
from torch import Tensor, jit


class PostFilter(jit.ScriptModule):
    def __init__(self):
        super(PostFilter, self).__init__()

    @jit.script_method
    def forward(self, d_filter):
        # d_filter: [batch, channel, time, freq]
        b, c, t, f = d_filter.size()
        w = torch.zeros([b, c, 1, f], dtype=d_filter.dtype, device=d_filter.device)
        l = torch.jit.annotate(List[Tensor], [])
        # Sequential scan over the time axis: each step depends on the previous output.
        for i in range(t):
            w = 0.995 * w + (1 - 0.995) * d_filter[:, :, i:i + 1]
            l += [w]
        d_filter_new = torch.cat(l, dim=2)
        return d_filter_new

In the current release (1.12.1) I’m seeing a 3x speedup when this method is scripted:

import time
from typing import List

import torch
from torch import Tensor


class PostFilterEager(torch.nn.Module):
    def __init__(self):
        super(PostFilterEager, self).__init__()

    def forward(self, d_filter):
        b, c, t, f = d_filter.size()
        w = torch.zeros([b, c, 1, f], dtype=d_filter.dtype, device=d_filter.device)
        l = []
        for i in range(t):
            w = 0.995 * w + (1 - 0.995) * d_filter[:, :, i:i + 1]
            l += [w]
        d_filter_new = torch.cat(l, dim=2)
        return d_filter_new


# Same computation as PostFilterEager, but compiled with TorchScript.
class PostFilter(torch.jit.ScriptModule):
    def __init__(self):
        super(PostFilter, self).__init__()

    @torch.jit.script_method
    def forward(self, d_filter):
        b, c, t, f = d_filter.size()
        w = torch.zeros([b, c, 1, f], dtype=d_filter.dtype, device=d_filter.device)
        l = torch.jit.annotate(List[Tensor], [])
        for i in range(t):
            w = 0.995 * w + (1 - 0.995) * d_filter[:, :, i:i + 1]
            l += [w]
        d_filter_new = torch.cat(l, dim=2)
        return d_filter_new

# eager
# setup
model = PostFilterEager()
x = torch.randn(64, 64, 64, 64, device='cuda')

# warmup
for _ in range(10):
    out = model(x)

# profile
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
    out = model(x)
torch.cuda.synchronize()
t1 = time.perf_counter()
t_eager = (t1 - t0)/1000
print(t_eager)

# scripted
# setup
model = PostFilter()
x = torch.randn(64, 64, 64, 64, device='cuda')

# warmup
for _ in range(10):
    out = model(x)

# profile
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
    out = model(x)
torch.cuda.synchronize()
t1 = time.perf_counter()
t_scripted = (t1 - t0)/1000
print(t_scripted)

# speedup of scripted over eager
speedup = t_eager / t_scripted
print(speedup)

What is your expectation regarding the performance of this operation?


Thank you for your reply. This part of my network is the speed bottleneck during training, so I'm wondering whether there is a better implementation that could speed it up.
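
One option worth trying (not from the thread, just a sketch): because the filter is a linear recurrence with a constant coefficient and w starts at zero, it can be unrolled into the closed form a(i) = (1 - alpha) * sum over j <= i of alpha^(i - j) * in(j). That removes the sequential loop entirely and turns the whole scan into a single matmul over the time axis. The function name ema_matmul and the einsum formulation below are illustrative, not from the thread:

import torch

def ema_matmul(x, alpha: float = 0.995):
    # x: [b, c, t, f]. Computes a(i) = alpha * a(i-1) + (1 - alpha) * x(i)
    # with a(-1) = 0, via the unrolled closed form
    #   a(i) = (1 - alpha) * sum_{j <= i} alpha**(i - j) * x(j)
    t = x.size(2)
    idx = torch.arange(t, device=x.device)
    exponents = (idx.unsqueeze(1) - idx.unsqueeze(0)).clamp(min=0).to(x.dtype)
    weights = (1 - alpha) * alpha ** exponents  # weights[i, j] = (1 - alpha) * alpha**(i - j)
    weights = weights.tril()                    # zero out j > i: no future inputs
    # out[b, c, i, f] = sum_j weights[i, j] * x[b, c, j, f]
    return torch.einsum('ij,bcjf->bcif', weights, x)

You can sanity-check it against the loop version with torch.allclose(ema_matmul(x), PostFilterEager()(x), atol=1e-5). For t = 64 and alpha = 0.995 the smallest weight is about 0.995**63 ≈ 0.73, so the weight matrix stays well-conditioned; for much longer sequences the t-by-t matrix grows quadratically, and a chunked variant may be preferable.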