I want to implement a filter of the form a(t) = 0.995 * a(t-1) + (1 - 0.995) * in(t). I use TorchScript to speed it up, but it is still very slow. Is there a better way to implement it?
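Unrolling the recurrence shows its structure. Write α = 0.995 and assume a zero initial state a(-1) = 0, which is what the `torch.zeros` initialization in the code implies. Each output is then a geometrically weighted sum of all inputs up to time t:

```latex
a(t) = \alpha\, a(t-1) + (1-\alpha)\, \mathrm{in}(t)
     = (1-\alpha) \sum_{k=0}^{t} \alpha^{\,t-k}\, \mathrm{in}(k) \;+\; \alpha^{\,t+1}\, a(-1)
```

With a(-1) = 0 the last term vanishes. This is a first-order IIR low-pass filter, also called an exponential moving average. Every output is a linear combination of the inputs, so the whole sequence can in principle be computed as a matrix product over the time axis rather than with a step-by-step loop.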
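import torch
from torch import Tensor, jit
from typing import List

class PostFilter(jit.ScriptModule):
    def __init__(self):
        super(PostFilter, self).__init__()

    @jit.script_method
    def forward(self, d_filter: Tensor) -> Tensor:
        # d_filter: (batch, channels, time, freq); the filter runs along dim 2 (time)
        b, c, t, f = d_filter.size()
        # filter state a(t-1), initialised to zero
        w = torch.zeros([b, c, 1, f], dtype=d_filter.dtype, device=d_filter.device)
        l = torch.jit.annotate(List[Tensor], [])
        for i in range(t):
            # a(t) = 0.995 * a(t-1) + (1 - 0.995) * in(t)
            w = 0.995 * w + (1 - 0.995) * d_filter[:, :, i:i + 1]
            l.append(w)
        d_filter_new = torch.cat(l, dim=2)
        return d_filter_new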
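One possible alternative is the following sketch. It has not been benchmarked against the original model. Because the filter is linear, the Python time loop can be replaced by a single product with a T × T lower-triangular weight matrix whose entry (t, k) is (1 - α) * α^(t-k) for k ≤ t. The function name `ema_filter_matmul` and its `alpha` parameter are hypothetical. The sketch assumes the same (batch, channels, time, freq) layout and zero initial state as `PostFilter`.

```python
import torch


def ema_filter_matmul(d_filter: torch.Tensor, alpha: float = 0.995) -> torch.Tensor:
    """Hypothetical vectorized form of a(t) = alpha * a(t-1) + (1 - alpha) * x(t).

    Filters along dim 2 (time) of a (batch, channels, time, freq) tensor,
    assuming a(-1) = 0 as in PostFilter.
    """
    t = d_filter.size(2)
    idx = torch.arange(t, device=d_filter.device)
    # diff[t, k] = t - k; negative entries (future inputs) are masked out below.
    diff = (idx.unsqueeze(1) - idx.unsqueeze(0)).to(d_filter.dtype)
    # weights[t, k] = (1 - alpha) * alpha**(t - k) for k <= t, 0 otherwise.
    weights = torch.tril((1 - alpha) * alpha ** diff.clamp(min=0))
    # out[b, c, t, f] = sum_k weights[t, k] * x[b, c, k, f]: one parallel op, no loop.
    return torch.einsum("tk,bckf->bctf", weights, d_filter)
```

The design trade-off is O(T²) memory and arithmetic for the weight matrix. In exchange, you get one large parallel operation instead of several small kernel launches per time step, which is likely to help on a GPU when T is moderate. Inside a module, `forward` could simply return `ema_filter_matmul(d_filter)`. If T is fixed, the weight matrix could be precomputed once and stored as a buffer.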
Thank you for your reply. This filter is the speed bottleneck of my network during training, so I'm wondering whether there is a faster way to implement this part.
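If T is too long for a full T × T weight matrix, a middle ground is to split the time axis into chunks. This is a hypothetical sketch, not a library API. The matrix form is applied inside each chunk, and only the last output is carried across chunk boundaries. The carry uses the unrolled recurrence: for a chunk starting at s, a(s+i) = α^(i+1) a(s-1) + (1 - α) Σ_{j≤i} α^(i-j) in(s+j). The names `ema_filter_chunked` and `chunk` are illustrative. With this layout the Python loop runs about T / chunk times instead of T times.

```python
import torch


def ema_filter_chunked(x: torch.Tensor, alpha: float = 0.995, chunk: int = 256) -> torch.Tensor:
    """Hypothetical chunked form of a(t) = alpha * a(t-1) + (1 - alpha) * x(t) along dim 2."""
    b, c, t, f = x.shape
    if t == 0:
        return x.clone()
    n = min(chunk, t)
    idx = torch.arange(n, device=x.device)
    diff = (idx.unsqueeze(1) - idx.unsqueeze(0)).to(x.dtype)
    # Within-chunk weights: W[i, j] = (1 - alpha) * alpha**(i - j) for j <= i, else 0.
    weights = torch.tril((1 - alpha) * alpha ** diff.clamp(min=0))
    # decay[i] = alpha**(i + 1): how much of the carried-in state survives at local step i.
    decay = (alpha ** (idx.to(x.dtype) + 1)).view(1, 1, n, 1)
    state = x.new_zeros(b, c, 1, f)  # a(-1) = 0, as in PostFilter
    outs = []
    for start in range(0, t, n):
        block = x[:, :, start:start + n]
        m = block.size(2)  # the final chunk may be shorter than n
        y = torch.einsum("ij,bcjf->bcif", weights[:m, :m], block) + decay[:, :, :m] * state
        outs.append(y)
        state = y[:, :, -1:]  # last output becomes a(s-1) for the next chunk
    return torch.cat(outs, dim=2)
```

Setting `chunk` equal to T reduces this to the single-matrix version. Smaller chunks bound memory at chunk² weights, at the cost of more loop iterations. As a further option, if torchaudio is available, `torchaudio.functional.lfilter` with `a_coeffs = [1, -0.995]` and `b_coeffs = [0.005, 0]` describes the same recurrence. It filters along the last dimension, so the time axis would have to be moved there first. It also clamps outputs to [-1, 1] unless `clamp=False` is passed.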