How to visualize profile-guided optimization in TorchScript?

I am currently studying TorchScript. I came across profile-guided optimization, a technique used in TorchScript that collects information about tensors and the operations performed on them.

Profile-guided optimization uses prim::profile nodes to record this information. My doubt is: is there any way to visualize the graph (intermediate representation) with prim::profile and prim::Guard nodes in PyTorch 1.5?

Please help me with this doubt.

Hi, yes you can run:

old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
torch._C._jit_set_num_profiled_runs(num_runs)

run your scripted function/module a number of times,

then run:

torch.jit.last_executed_optimized_graph()
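
Both _jit_set_ calls return the previous setting, which is what old_prof_exec_state and old_prof_mode_state hold, so if you want to go back to the original executor configuration afterwards you can restore them. A minimal sketch:

# restore the settings saved above once you are done inspecting graphs
torch._C._jit_set_profiling_executor(old_prof_exec_state)
torch._C._jit_set_profiling_mode(old_prof_mode_state)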

Hi sir,
I can visualize prim::profile.

Can you please explain the solution you have provided in more detail? I cannot understand the set of commands you have given.

old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)

These two calls enable the profiling executor and profiling mode.

torch._C._jit_set_num_profiled_runs(num_runs)

This sets how many profiling runs we want to do before we optimize the graph.

@torch.jit.script
def foo(x):
    if x.size(0) == 1:
        return 1
    else:
        return 2

old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
torch._C._jit_set_num_profiled_runs(1)

foo(torch.rand([1, 2]))
print(torch.jit.last_executed_optimized_graph())
foo(torch.rand([1, 2]))
print(torch.jit.last_executed_optimized_graph())

gives

graph(%x.1 : Tensor):
  %1 : int = prim::Constant[value=0]() # test/test_jit.py:3742:22
  %2 : int = prim::Constant[value=1]() # test/test_jit.py:3742:28
  %3 : int = prim::Constant[value=2]() # test/test_jit.py:3745:23
  %4 : Tensor = prim::profile(%x.1)
  %5 : int = aten::size(%4, %1) # test/test_jit.py:3742:15
  %6 : bool = aten::eq(%5, %2) # test/test_jit.py:3742:15
  %7 : int = prim::If(%6) # test/test_jit.py:3742:12
    block0():
      -> (%2)
    block1():
      -> (%3)
   = prim::profile()
  return (%7)

The first graph, still profiling.

graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]() # test/test_jit.py:3742:28
  %10 : int = prim::BailoutTemplate_0()
  %9 : Double(1:2, 2:1) = prim::BailOut[index=0](%10, %x.1)
  return (%2)
with prim::BailoutTemplate_0 = graph(%x.1 : Tensor):
  %1 : Double(1:2, 2:1) = prim::BailOut[index=0](%x.1)
  %2 : int = prim::Constant[value=0]() # test/test_jit.py:3742:22
  %3 : int = prim::Constant[value=1]() # test/test_jit.py:3742:28
  %4 : int = prim::Constant[value=2]() # test/test_jit.py:3745:23
  %5 : int = aten::size(%1, %2) # test/test_jit.py:3742:15
  %6 : bool = aten::eq(%5, %3) # test/test_jit.py:3742:15
  %7 : int = prim::If(%6) # test/test_jit.py:3742:12
    block0():
      -> (%3)
    block1():
      -> (%4)
  return (%7)

The second graph, optimized from the collected profiles.
Additionally, you can read
https://github.com/pytorch/pytorch/blob/53af9df557aff745edf24193ece784fd008c6f19/torch/csrc/jit/OVERVIEW.md#profiling-programs.

Hello sir,
I tried to compare the runtime of a convolution layer I implemented from scratch vs. the TorchScript version of that layer vs. the torch.nn.Conv2d() module, for 100 iterations with input (128, 3, 28, 28), out_channels = 64, kernel size = 3.

Convolution layer from scratch in CUDA -> 9.366 seconds
TorchScript convolution layer from scratch in CUDA -> 6.636 seconds
torch.nn.Conv2d() -> 475.614 milliseconds.
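
One caveat with wall-clock timings of CUDA code is that kernel launches are asynchronous, so the clock should only be read after torch.cuda.synchronize(). A minimal timing sketch (the helper name time_module and the iteration count are just placeholders):

import time

import torch

def time_module(mod, x, n_iters=100):
    with torch.no_grad():
        mod(x)                        # warm-up call so one-time setup costs are excluded
        torch.cuda.synchronize()      # wait for pending kernels before starting the clock
        start = time.time()
        for _ in range(n_iters):
            mod(x)
        torch.cuda.synchronize()      # make sure all launched kernels have finished
        elapsed = time.time() - start
    return elapsed

Measured this way, the scripted and eager versions can be compared on equal footing.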

My code

class conv2D(nn.Module):
  def __init__(self, in_channel, out_channel, kernel_size):
    super(conv2D,self).__init__()
    self.weight = torch.nn.Parameter(torch.ones(out_channel,in_channel,kernel_size, kernel_size))
    self.bias = torch.nn.Parameter(torch.zeros(out_channel))
    self.kernel_size = kernel_size
    self.in_channel = in_channel
    self.out_channel = out_channel

  def forward(self, image):
    # image is (batch, channels, height, width)
    batch_size = image.shape[0]
    img_height = image.shape[2]
    img_width = image.shape[3]
    out_height = img_height-self.kernel_size+1
    out_width = img_width-self.kernel_size+1

    # allocate the output on the same device/dtype as the input
    output = torch.zeros(batch_size,self.out_channel,out_height,out_width, device=image.device, dtype=image.dtype)
    for k in range(batch_size):
      for i in range(out_height):
        for j in range(out_width):
          # multiply the receptive field by every filter and sum over (C_in, kH, kW)
          temp = torch.sum(image[k,:,i:i+self.kernel_size,j:j+self.kernel_size]*self.weight,dim=(1,2,3))
          output[k,:,i,j]=torch.add(temp,self.bias)
    return output  # return only after the whole batch is processed
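
For reference, the same convolution can also be expressed without Python loops via an im2col-style torch.nn.functional.unfold. A rough sketch (not the code benchmarked above), assuming the standard (N, C, H, W) layout, stride 1 and no padding:

import torch
import torch.nn.functional as F

def conv2d_unfold(image, weight, bias, kernel_size):
    # image: (N, C_in, H, W); weight: (C_out, C_in, k, k); bias: (C_out,)
    n = image.shape[0]
    out_h = image.shape[2] - kernel_size + 1
    out_w = image.shape[3] - kernel_size + 1
    # im2col: (N, C_in * k * k, L) where L = out_h * out_w
    cols = F.unfold(image, kernel_size)
    # (C_out, C_in * k * k) @ (N, C_in * k * k, L) -> (N, C_out, L)
    out = weight.view(weight.shape[0], -1) @ cols + bias.view(1, -1, 1)
    return out.view(n, weight.shape[0], out_h, out_w)

This replaces the batch_size * out_height * out_width tiny operations with one large batched matrix multiply, which is closer to what optimized convolution implementations do.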

Scripting the model and running it with a sample input to get an optimized graph:

x = torch.ones(128,3,28,28).to("cuda")
c = conv2D(3,64,3).to("cuda")
c_s = torch.jit.script(c).to("cuda")
c_s(x)

Profiling both the scripted and the normal (eager) module:

with torch.autograd.profiler.profile(use_cuda=True) as prof:
  with torch.no_grad():
    for i in range(100):
      c(x)
print(prof.table())
with torch.autograd.profiler.profile(use_cuda=True) as prof:
  with torch.no_grad():
    for i in range(100):
      c_s(x)
print(prof.table())
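
To make the profiler output easier to compare, the per-op results can also be aggregated and sorted by total CUDA time using the profiler's built-in key_averages():

# aggregate identical ops and sort by total CUDA time
print(prof.key_averages().table(sort_by="cuda_time_total"))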

Is there any problem with my approach, and how can I optimize it even more? I request your help with this problem.

Hi,

Sorry for the delay. Currently we only generate new optimized kernels for a series of pointwise ops on GPU. If that is not part of your use case, it is unlikely you will see a speedup at the moment.
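
As a rough illustration of what does get fused (a sketch with an arbitrary function name and tensor sizes, assuming the profiling executor is enabled as described earlier): a scripted chain of elementwise ops on CUDA tensors can be compiled into a single fused kernel once the profiling runs complete, and the fusion group is visible in the optimized graph.

@torch.jit.script
def pointwise_chain(x, y):
    # only elementwise ops -> a candidate for GPU kernel fusion
    return torch.sigmoid(x * y) * torch.tanh(x + y)

a = torch.rand(1024, 1024, device="cuda")
b = torch.rand(1024, 1024, device="cuda")
for _ in range(3):  # let the profiling runs finish so the optimized graph is built
    pointwise_chain(a, b)
# the fused kernel should show up as a fusion group node in the graph below
print(torch.jit.last_executed_optimized_graph())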

Sir, but wouldn't loop unrolling provide higher performance?