Problem symbolically tracing nn.GRUCell/LSTMCell with torch.fx

Hello,
I am trying to symbolically trace a model containing an nn.GRUCell using torch.fx to prepare it for quantization.
I tried two approaches:
1. Define the states of the model using self.states
Here is the code:

import torch 
import torch.nn as nn
from pulse_ai.tiny.helper_functions import (prepare, profile)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx


class GRU(nn.Module):
    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        # Initialize empty states
        self.states = torch.zeros(1, hidden_size)
        #self.register_buffer('states', states)
        self.rnn = nn.GRUCell(input_size, hidden_size)

    def forward(self, x):
        self.states = self.rnn(x, self.states) 
        return self.states
    
TestGRU = GRU(128)

dummy_input = torch.rand(1,128)

outputs = TestGRU(dummy_input)

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)

symbolic_traced.graph.print_tabular()

default_mapping = get_default_qconfig_mapping("qnnpack")

model_prepared = quantize_fx.prepare_fx(
                symbolic_traced,
                default_mapping,
                dummy_input,
)

This is the error I got:

/home/ahmed/anaconda3/envs/tiny_pulse_env/bin/python /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py
opcode       name    target    args         kwargs
-----------  ------  --------  -----------  --------
placeholder  x       x         ()           {}
get_attr     states  states    ()           {}
call_module  rnn     rnn       (x, states)  {}
output       output  output    (rnn,)       {}
Traceback (most recent call last):
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 36, in <module>
    model_prepared = quantize_fx.prepare_fx(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 380, in prepare_fx
    return _prepare_fx(
           ^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 137, in _prepare_fx
    graph_module = _fuse_fx(
                   ^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 89, in _fuse_fx
    return fuse(
           ^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 115, in fuse
    env[node.name] = fused_graph.node_copy(node, load_arg)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/graph.py", line 1163, in node_copy
    args = map_arg(node.args, arg_transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 629, in map_aggregate
    t = tuple(map_aggregate(elem, fn) for elem in a)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 629, in <genexpr>
    t = tuple(map_aggregate(elem, fn) for elem in a)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 639, in map_aggregate
    return fn(a)
           ^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
                                      ^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 86, in load_arg
    return map_arg(a, lambda node: env[node.name])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 639, in map_aggregate
    return fn(a)
           ^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
                                      ^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 86, in <lambda>
    return map_arg(a, lambda node: env[node.name])
                                   ~~~^^^^^^^^^^^
KeyError: 'rnn'

2. Define the states using self.register_buffer('states', states)
Code:

import torch 
import torch.nn as nn
from pulse_ai.tiny.helper_functions import (prepare, profile)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx


class GRU(nn.Module):
    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        # Initialize empty states
        states = torch.zeros(1, hidden_size)
        self.register_buffer('states', states)
        self.rnn = nn.GRUCell(input_size, hidden_size)

    def forward(self, x):
        self.states = self.rnn(x, self.states) 
        return self.states
    
TestGRU = GRU(128)

dummy_input = torch.rand(1,128)

outputs = TestGRU(dummy_input)

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)

symbolic_traced.graph.print_tabular()

default_mapping = get_default_qconfig_mapping("qnnpack")

model_prepared = quantize_fx.prepare_fx(
                symbolic_traced,
                default_mapping,
                dummy_input,
)

Error:

Traceback (most recent call last):
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 30, in <module>
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)
                                             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 19, in forward
    self.states = self.rnn(x, self.states) 
    ^^^^^^^^^^^
  File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1665, in __setattr__
    raise TypeError("cannot assign '{}' as buffer '{}' "
TypeError: cannot assign 'Proxy(add)' as buffer 'states' (torch.Tensor or None expected)

In both approaches, I think the problem is passing self.states inside the forward function.
Can you suggest any solution?
(Maybe try to externally update the states of the cell, as in the sketch below!)
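
To show what I mean by updating the states externally, here is a rough, untested sketch where the hidden state is passed in and returned instead of being stored on self. It at least avoids the "cannot assign Proxy as buffer" error from approach 2, although the GRUCell itself may still not be supported by prepare_fx:

import torch
import torch.nn as nn

class GRUExternalState(nn.Module):
    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        self.rnn = nn.GRUCell(input_size, hidden_size)

    def forward(self, x, states):
        # The caller owns the state: nothing is assigned to self here,
        # so tracing never tries to store a Proxy on the module.
        return self.rnn(x, states)

model = GRUExternalState(128)
states = torch.zeros(1, 128)
dummy_input = torch.rand(1, 128)
states = model(dummy_input, states)  # the caller carries the state between calls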

Hoping to hear from you soon,
Ahmed

Yeah, GRU is not symbolically traceable, but you can trace it with torchdynamo using the following code:

import torch
import torch._dynamo as torchdynamo
import copy
torchdynamo.config.allow_rnn = True


class RNNDynamicModel(torch.nn.Module):
    def __init__(self, mod_type):
        super().__init__()
        if mod_type == "GRU":
            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
        if mod_type == "LSTM":
            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x):
        x = self.mod(x)
        return x


niter = 10
example_inputs = (
    torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float)
    .unsqueeze(0)
    .repeat(niter, 1, 1),
)
model = RNNDynamicModel("GRU")
model, guards = torchdynamo.export(
    model,
    *copy.deepcopy(example_inputs),
    aten_graph=True,
    tracing_mode="real",
)

This will work with our new quantization tool in PyTorch 2.0 export (pytorch/_quantize_pt2e.py at main · pytorch/pytorch · GitHub). I'm actually trying to figure out how we can quantize it right now.
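
For later readers: in newer PyTorch releases the PT2E entry points live under torch.ao.quantization.quantize_pt2e. A rough sketch of the flow, continuing from the export above (module paths, the quantizer choice, and the export API have changed between versions, so treat this as an assumption rather than the exact API from this thread):

from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# `model` and `example_inputs` come from the export snippet above; newer
# versions may require re-exporting with the current torch.export flow first.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())

prepared = prepare_pt2e(model, quantizer)   # insert observers
prepared(*example_inputs)                   # calibrate with representative data
quantized = convert_pt2e(prepared)          # produce the quantized model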


If you need to use FX graph mode quantization, the closest is LSTM custom module quantization:

where you need to write custom observed and quantized modules for GRU (breaking the single GRU function call down into calls to linear layers and non-linearities) in order to quantize it.
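
As an illustration of that decomposition, here is a minimal sketch of a GRU cell written in terms of Linear layers and pointwise non-linearities (my own example following the standard GRU equations, not the actual observed/quantized module shipped in torch.ao):

import torch
import torch.nn as nn

class DecomposedGRUCell(nn.Module):
    # GRU cell expressed as Linear layers + sigmoid/tanh, so each piece
    # can be observed and quantized individually.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Projections for the reset (r), update (z) and new (n) gates.
        self.x2gates = nn.Linear(input_size, 3 * hidden_size)
        self.h2gates = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h):
        xr, xz, xn = self.x2gates(x).chunk(3, dim=-1)
        hr, hz, hn = self.h2gates(h).chunk(3, dim=-1)
        r = torch.sigmoid(xr + hr)      # reset gate
        z = torch.sigmoid(xz + hz)      # update gate
        n = torch.tanh(xn + r * hn)     # candidate hidden state
        return (1 - z) * n + z * h      # new hidden state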

Thank you for your quick reply!
Is it easier to quantize LSTM/GRU layers using Eager Mode?

I don't think it's related to eager mode or FX graph mode; it depends on whether you need to customize how each submodule of the quantizable LSTM is quantized. If you can use a global qconfig for all submodules, then it will be simpler, in both FX graph mode and eager mode.
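
For example, a single global qconfig might look like this in the two flows (a sketch assuming the default qnnpack qconfig; the toy model here is just a placeholder):

import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping

model = torch.nn.Sequential(torch.nn.Linear(2, 2))  # placeholder model

# FX graph mode: one global qconfig applied to every submodule
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("qnnpack"))

# Eager mode: attach the same qconfig to the whole model before prepare/convert
model.qconfig = get_default_qconfig("qnnpack")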

Thank you @jerryzh168 !
I just wanted to make sure that after this custom quantization, we won't be able to export the quantized layers to ONNX (at least for the moment!)

Thank you,

Yeah, I haven't tried, but it's likely not exportable to ONNX.

Hi Jerry, I just saw this. Where can we find some docs on PT2E export? I only see the code right now and don't know its usage and advantages.

It's in development and still experimental; as a result, the docs are not complete yet, since design changes may still occur.


Please stay tuned, we'll have something after the end of this week and I'll post it here.


OK, two docs here:

There are still some formatting issues (e.g. some code didn't display properly) that will be fixed next week.

Cool, I just read the PT2E tutorial and am very glad to see that we can use fewer configs to quantize a model.
I struggled with FX's many configs for a while; I still cannot fully understand the differences between qconfig, qconfig_mapping, equalization_config, and backend_config…
