Cannot export call stacks for flamegraphs with Pytorch profiler?

Cross-posted as an issue: stacks file from profiler is empty · Issue #89406 · pytorch/pytorch · GitHub, but I’m 99% sure it’s not a bug and just me misusing the API somehow.

I have a small reproducible example below that should (theoretically) write out a .stacks file that I can feed into a flamegraph generator.

In reality, the file is empty.
Does anyone know why?
I can’t see anything obviously wrong with my usage of the API either.

import torch

class Mlp(torch.nn.Module):
    def __init__(self):
        self.fc1 = torch.nn.Linear(100, 100)
        self.act = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(100, 100)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.mlp1 = Mlp()
        self.mlp2 = Mlp()

    def forward(self, x):
        return self.mlp2(self.mlp1(x))

from import DataLoader
import torchdata.datapipes.iter as dp
from torch.profiler import profile, ProfilerActivity

def make_mock_dataloader():
    pipe = dp.IterableWrapper([torch.rand(100) for _ in range(1000)])
    return DataLoader(pipe, batch_size=32, num_workers=2, drop_last=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net().to(device=device)

target_traces = 2
traces_saved = 0

def trace_handler(prof: "torch.profiler.profile"):
    global traces_saved
    from os.path import join

    print("SAVING TRACE")

    tb_dir = join("./output2", "traces", str(traces_saved))
    handler = torch.profiler.tensorboard_trace_handler(
        tb_dir, worker_name=f"rank0"

    prof.export_stacks(path=join(tb_dir, f"rank0.cuda.stacks"), metric="self_cuda_time_total")
    prof.export_stacks(path=join(tb_dir, f"rank0.cpu.stacks"), metric="self_cpu_time_total")

    # print(

    traces_saved += 1
    if traces_saved == target_traces:

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # profile_memory=True,
    # with_modules=True,
    # record_shapes=True,
        skip_first=5, wait=1, warmup=5, active=5, repeat=target_traces

for idx, batch in enumerate(make_mock_dataloader()):
    print(f"idx: {idx}")
    batch =
    out = net(batch)