StopIteration error with TransformerEncoderLayer

Hi there, I tried posting this under the FX category, but it doesn’t seem like that’s an active topic, so I thought I’d try my luck here.

I’m trying to run the FX profiling tutorial (fx_profiling_tutorial.py on the master branch of the pytorch/tutorials repository on GitHub) on a single nn.TransformerEncoderLayer instead of the ResNet18 used in the example, and I keep running into a StopIteration error. Why is this happening? All I did was replace the ResNet with a transformer encoder layer. Here is the code:

# -*- coding: utf-8 -*-
"""
(beta) Building a Simple CPU Performance Profiler with FX
*******************************************************
**Author**: `James Reed <https://github.com/jamesr66a>`_

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that we can inspect and gather
   statistics about the structure and execution of the code
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.

"""

######################################################################
# For this run, we are going to use a single ``nn.TransformerEncoderLayer``
# in place of the tutorial's torchvision ResNet18 model.

import torch
import torch.fx

model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8)
model.eval()

######################################################################
# Now that we have our model, we want to inspect deeper into its
# performance. That is, for the following invocation, which parts
# of the model are taking the longest?
input = torch.randn(10, 32, 512)
output = model(input)

######################################################################
# A common way of answering that question is to go through the program
# source, add code that collects timestamps at various points in the
# program, and compare the difference between those timestamps to see
# how long the regions between the timestamps take.
#
# That technique is certainly applicable to PyTorch code; however, it
# would be nicer if we didn't have to copy over model code and edit it,
# especially code we haven't written (like this built-in ``nn`` module).
# Instead, we are going to use FX to automate this "instrumentation"
# process without needing to modify any source.

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

######################################################################
# .. note::
#     ``tabulate`` is an external library that is not a dependency of PyTorch.
#     We will be using it to more easily visualize performance data. Please
#     make sure you've installed it from your favorite Python package source.

######################################################################
# Capturing the Model with Symbolic Tracing
# -----------------------------------------
# Next, we are going to use FX's symbolic tracing mechanism to capture
# the definition of our model in a data structure we can manipulate
# and examine.

# traced = torch.fx.symbolic_trace(model)
# print(traced.graph)

######################################################################
# This gives us a Graph representation of the model. A Graph
# consists of a series of Nodes connected to each other. Each Node
# represents a call-site in the Python code (whether to a function,
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
# on each node) represent the values passed between these call-sites. More
# information about the Graph representation and the rest of FX's APIs can
# be found in the FX documentation at https://pytorch.org/docs/master/fx.html.


######################################################################
# Creating a Profiling Interpreter
# --------------------------------
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
# that is run when you call a ``GraphModule``, an alternative way to run a
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
# the functionality that ``Interpreter`` provides: It interprets the graph node-
# by-node.
#
# By inheriting from ``Interpreter``, we can override various functionality and
# install the profiling behavior we want. The goal is to have an object to which
# we can pass a model, invoke the model one or more times, then get statistics about
# how long the model and each part of the model took during those runs.
#
# Let's define our ``ProfilingInterpreter`` class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entrypoint for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ProfilingInterpreter
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

######################################################################
# .. note::
#       We use Python's ``time.time`` function to pull wall clock
#       timestamps and compare them. This is not the most accurate
#       way to measure performance, and will only give us a first-
#       order approximation. We use this simple technique only for the
#       purpose of demonstration in this tutorial.

######################################################################
# Investigating the Performance of the Model
# ------------------------------------------
# We can now use ``ProfilingInterpreter`` to inspect the performance
# characteristics of our model:

interp = ProfilingInterpreter(model)
interp.run(input)
print(interp.summary(True))

######################################################################
# In the original tutorial's ResNet18 run, there are two things to call out:
#
# * MaxPool2d takes up the most time. This is a known issue:
#   https://github.com/pytorch/pytorch/issues/51393
# * BatchNorm2d also takes up significant time. We can continue this
#   line of thinking and optimize this in the Conv-BN Fusion with FX
#   tutorial TODO: link
#
#
# Conclusion
# ----------
# As we can see, using FX we can easily capture PyTorch programs (even
# ones we don't have the source code for!) in a machine-interpretable
# format and use that for analysis, such as the performance analysis
# we've done here. FX opens up an exciting world of possibilities for
# working with PyTorch programs.
#
# Finally, since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.
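A small diagnostic sketch that may help narrow this down. torch.fx.symbolic_trace creates one placeholder node for every parameter of forward(), optional ones included, and Interpreter.run fills those placeholders, in order, from the positional arguments it receives. nn.TransformerEncoderLayer.forward takes optional mask arguments in addition to src, so one thing worth checking is how many placeholders the traced graph has compared with the single argument passed to interp.run(input). The module WithOptionalMask below is hypothetical (not part of PyTorch); it only mimics a forward() with an optional mask, lists the graph's placeholders, and then passes a value for every one of them.

```python
import torch
import torch.fx


class WithOptionalMask(torch.nn.Module):
    """Hypothetical module whose forward() has an optional argument,
    similar in shape to a layer that accepts an optional attention mask."""

    def forward(self, x, mask=None):
        # The mask is deliberately unused so the module traces cleanly;
        # it still becomes a placeholder node in the FX graph.
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(WithOptionalMask())

# One placeholder per forward() parameter, optional ones included.
placeholders = [n.target for n in gm.graph.nodes if n.op == "placeholder"]
print("graph expects:", placeholders)  # → graph expects: ['x', 'mask']

x = torch.randn(2, 3)
# Pass a value for every placeholder (None for the optional mask),
# so the interpreter never runs out of positional arguments.
out = torch.fx.Interpreter(gm).run(x, None)
print(torch.equal(out, torch.relu(x) + 1))  # → True
```

For the script above, the corresponding check would be to print the placeholder targets of torch.fx.symbolic_trace(model).graph; if the mask arguments show up there as well, passing an explicit value (such as None) for each of them to interp.run(...) is one thing to try. This is a way to narrow the problem down under the assumptions stated here, not a confirmed cause of the StopIteration.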
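The note in the script points out that time.time only gives a first-order approximation. A common refinement is time.perf_counter, the standard-library clock intended for measuring short durations. Below is a minimal sketch of an Interpreter subclass that times every node with it; the class name PerfCounterInterpreter and the small Linear/ReLU model are illustrative assumptions, not part of the tutorial.

```python
import time
from typing import Any, Dict, List

import torch
import torch.fx


class PerfCounterInterpreter(torch.fx.Interpreter):
    """Hypothetical variant of the tutorial's profiler that times each
    node with time.perf_counter instead of time.time."""

    def __init__(self, mod: torch.nn.Module):
        super().__init__(torch.fx.symbolic_trace(mod))
        # Map from node name to the duration (seconds) of each execution.
        self.runtimes_sec: Dict[str, List[float]] = {}

    def run_node(self, n: torch.fx.Node) -> Any:
        t_start = time.perf_counter()  # high-resolution clock for short durations
        result = super().run_node(n)
        elapsed = time.perf_counter() - t_start
        self.runtimes_sec.setdefault(n.name, []).append(elapsed)
        return result


toy_model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
toy_interp = PerfCounterInterpreter(toy_model)
for _ in range(3):
    toy_interp.run(torch.randn(2, 8))

# Report the fastest observed run for each node.
for name, times in toy_interp.runtimes_sec.items():
    print(f"{name}: {min(times):.6f} s (min of {len(times)} runs)")
```

Reporting the minimum over several runs, as the last loop does, is one way to reduce noise from other work on the machine; the tutorial's summary() uses the arithmetic mean instead, and either choice is a judgment call.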