AssertionError (assert self.profiler) when calling torch.profiler.profile's key_averages()

Hi all, I am trying the new profiler released in 1.8.1. My torch version is 1.8.1+cu102.

My code is below. I basically followed the torch.profiler documentation for PyTorch 1.8.1 and used the "Learning PyTorch with Examples" tutorial from the PyTorch 1.8.1+cu102 tutorials as the test code for profiling:

import torch
import torch.profiler

import math

# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Flatten(0, 1))
loss_fn = torch.nn.MSELoss(reduction="sum")

learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=18),
) as p:
    for t in range(20):
        y_pred = model(xx)

        loss = loss_fn(y_pred, y)
        if t % 10 == 9:
            print(t, loss.item())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

linear_layer = model[0]
print(
    f"Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3"
)

print(type(p))
print("=====================")
print(p.key_averages().table())

And I got the following output:

9 3212.458251953125
19 2424.00048828125
Result: y = 0.16131113469600677 + 0.11958461999893188 x + -0.23199747502803802 x^2 + 0.03287645801901817 x^3
<class 'torch.profiler.profiler.profile'>
=====================
Traceback (most recent call last):
  File ".../torch_profiler_test.py", line 45, in <module>
    print(p.key_averages().table())
  File ".../torch/profiler/profiler.py", line 325, in key_averages
    assert self.profiler
AssertionError

Does anyone know what went wrong? Thank you.

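Never mind, I found that I was missing p.step() in the loop. As far as I can tell, when a schedule is given the profiler only advances when step() is called, so without it profiling never left the initial wait step and nothing was recorded. After changing the loop to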

    for t in range(20):
        y_pred = model(xx)

        loss = loss_fn(y_pred, y)
        if t % 10 == 9:
            print(t, loss.item())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        p.step()  # was missing in the code above

I can get the table output.
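
For anyone hitting the same thing, here is a minimal sketch of why the missing p.step() matters. It assumes the schedule simply cycles through wait, warmup and active phases based on the step count. This is a simplified pure-Python model and not the actual torch.profiler code; Phase, phase_for_step and phases_seen are hypothetical names made up for illustration.

```python
from enum import Enum


class Phase(Enum):
    WAIT = "wait"      # profiler idle, nothing is recorded
    WARMUP = "warmup"  # profiler running, results discarded
    ACTIVE = "active"  # profiler recording


def phase_for_step(step, wait=1, warmup=1, active=18):
    """Hypothetical helper: phase of a wait/warmup/active cycle at `step`."""
    position = step % (wait + warmup + active)
    if position < wait:
        return Phase.WAIT
    if position < wait + warmup:
        return Phase.WARMUP
    return Phase.ACTIVE


def phases_seen(num_iterations, call_step):
    """Simulate a training loop; the step counter only moves when step() is called."""
    step = 0
    seen = [phase_for_step(step)]
    for _ in range(num_iterations):
        if call_step:
            step += 1  # stands in for p.step()
            seen.append(phase_for_step(step))
    return seen


without_step = phases_seen(20, call_step=False)
with_step = phases_seen(20, call_step=True)

print(Phase.ACTIVE in without_step)    # False: stuck in the wait phase
print(with_step.count(Phase.ACTIVE))   # 18
```

In this model, the loop without step() never leaves the wait phase, which would match the profiler having nothing to report when key_averages() is called.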