`torch.jit.fork` makes code slower

I ran a simple experiment and found that `torch.jit.fork` made the code slower on both CPU and CUDA:

import torch
import torch.nn as nn

class WithoutFork(nn.Module):
    """Baseline model: evaluates three independent Linear branches sequentially.

    Each branch maps 10 input features to 1000 output features; the forward
    pass returns the elementwise sum of the three branch outputs.
    """

    def __init__(self):
        super(WithoutFork, self).__init__()
        # Three independent branches with identical shapes (10 -> 1000).
        self.part1 = nn.Linear(10, 1000)
        self.part2 = nn.Linear(10, 1000)
        self.part3 = nn.Linear(10, 1000)

    def forward(self, x):
        # Run the branches one after another and accumulate the results.
        out = self.part1(x)
        out = out + self.part2(x)
        out = out + self.part3(x)
        return out

class WithFork(nn.Module):
    """Same three-branch model, but launches each branch with torch.jit.fork.

    fork() returns a Future immediately; wait() blocks until the branch has
    finished, so the three Linear layers may run concurrently (inter-op
    parallelism) when the module is scripted.
    """

    def __init__(self):
        super(WithFork, self).__init__()
        # Same three independent branches as the sequential baseline.
        self.part1 = nn.Linear(10, 1000)
        self.part2 = nn.Linear(10, 1000)
        self.part3 = nn.Linear(10, 1000)

    def forward(self, x):
        # Launch each branch asynchronously; each fork returns a Future.
        fut_a = torch.jit.fork(self.part1, x)
        fut_b = torch.jit.fork(self.part2, x)
        fut_c = torch.jit.fork(self.part3, x)

        # Block until every branch has produced its output.
        outputs = [torch.jit.wait(fut_a),
                   torch.jit.wait(fut_b),
                   torch.jit.wait(fut_c)]

        # Stack along a new leading dim and reduce it: elementwise sum.
        return torch.stack(outputs, 0).sum(0)


# CUDA benchmark setup: eager and TorchScript versions of both models.
without_fork = WithoutFork().to('cuda')
with_fork = WithFork().to('cuda')
# Scripting enables TorchScript optimizations, including actual inter-op
# parallelism for the fork/wait calls in WithFork.forward.
s_without_fork = torch.jit.script(without_fork)
s_with_fork = torch.jit.script(with_fork)
# Small input batch (100 x 10) — each branch does very little work.
x = torch.randn(100, 10).cuda()

%%timeit
_ = without_fork(x)
>> The slowest run took 20.22 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 113 µs per loop

%%timeit
_ = with_fork(x)
>> The slowest run took 17.37 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 174 µs per loop

%%timeit
_ = s_without_fork(x)
>> The slowest run took 524.84 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 88 µs per loop

%%timeit
_ = s_with_fork(x)
>> The slowest run took 25.88 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 238 µs per loop

# CPU

# CPU benchmark setup: same models and input, but left on the CPU.
without_fork = WithoutFork()
with_fork = WithFork()
s_without_fork = torch.jit.script(without_fork)
s_with_fork = torch.jit.script(with_fork)
x = torch.randn(100, 10)

%%timeit
_ = without_fork(x)
>> The slowest run took 6.67 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 528 µs per loop

%%timeit
_ = with_fork(x)
>> The slowest run took 8.37 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 785 µs per loop

%%timeit
_ = s_without_fork(x)
>> 1000 loops, best of 5: 516 µs per loop

%%timeit
_ = s_with_fork(x)
>> 1000 loops, best of 5: 829 µs per loop

So TorchScript is indeed doing optimizations — this answers my own earlier question about TorchScript optimization.

Check the number of inter-op parallelism threads you are using (`torch.get_num_interop_threads()`) — fork/wait can only run branches concurrently if more than one inter-op thread is available.

It could also be the case that your ops take so little time that the threading overhead is not worth the asynchronous processing.

1 Like