Hello,
I am trying to symbolically trace a model containing an nn.GRUCell using torch.fx to prepare it for quantization.
I tried 2 approaches:
1. Define the states of the model as a plain attribute, self.states
Here is the code:
import torch
import torch.nn as nn
from pulse_ai.tiny.helper_functions import (prepare, profile)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
class GRU(nn.Module):
    """Stateful wrapper around ``nn.GRUCell`` that keeps the hidden state on the module.

    The hidden state is stored in the plain tensor attribute ``self.states``
    (it is not registered as a buffer) and is overwritten by every call to
    ``forward``.

    NOTE(review): the traceback in this post shows ``prepare_fx`` failing with
    ``KeyError: 'rnn'`` for this variant. The reassignment of ``self.states``
    inside ``forward`` is the suspected cause -- TODO confirm.
    """

    def __init__(self, input_size, hidden_size=None):
        """Create the GRU cell and a zero-filled initial hidden state.

        Args:
            input_size: Number of input features given to ``nn.GRUCell``.
            hidden_size: Width of the hidden state; falls back to
                ``input_size`` when ``None``.
        """
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        # Initialize empty states as a plain attribute of shape (1, hidden_size).
        # (Approach 2 in this post uses self.register_buffer('states', states)
        # instead; that variant is left disabled here.)
        self.states = torch.zeros(1, hidden_size)
        self.rnn = nn.GRUCell(input_size, hidden_size)

    def forward(self, x):
        """Advance the cell by one step and return the new hidden state.

        Args:
            x: Input tensor for the cell. The (1, hidden_size) initial state
                suggests a batch size of 1 -- TODO confirm.

        Returns:
            The new hidden state, which is also stored in ``self.states``.
        """
        # Side effect: the stored state is replaced by the cell output.
        self.states = self.rnn(x, self.states)
        return self.states
# Reproduction script for approach 1 (state kept as a plain attribute).
# Build the wrapper with input_size=128; hidden_size falls back to input_size.
TestGRU = GRU(128)
dummy_input = torch.rand(1,128)
# Eager forward pass; GRU.forward also overwrites TestGRU.states with the result.
outputs = TestGRU(dummy_input)
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)
# Print the traced graph nodes (the table shown in the output below).
symbolic_traced.graph.print_tabular()
# Default qconfig mapping, requested with the "qnnpack" argument.
default_mapping = get_default_qconfig_mapping("qnnpack")
# Per the traceback reported below, this call raises KeyError: 'rnn'.
model_prepared = quantize_fx.prepare_fx(
symbolic_traced,
default_mapping,
dummy_input,
)
This is the error I got
/home/ahmed/anaconda3/envs/tiny_pulse_env/bin/python /home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py
opcode name target args kwargs
----------- ------ -------- ----------- --------
placeholder x x () {}
get_attr states states () {}
call_module rnn rnn (x, states) {}
output output output (rnn,) {}
Traceback (most recent call last):
File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 36, in <module>
model_prepared = quantize_fx.prepare_fx(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 380, in prepare_fx
return _prepare_fx(
^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 137, in _prepare_fx
graph_module = _fuse_fx(
^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py", line 89, in _fuse_fx
return fuse(
^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 115, in fuse
env[node.name] = fused_graph.node_copy(node, load_arg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/graph.py", line 1163, in node_copy
args = map_arg(node.args, arg_transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in map_arg
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 629, in map_aggregate
t = tuple(map_aggregate(elem, fn) for elem in a)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 629, in <genexpr>
t = tuple(map_aggregate(elem, fn) for elem in a)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 639, in map_aggregate
return fn(a)
^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in <lambda>
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 86, in load_arg
return map_arg(a, lambda node: env[node.name])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in map_arg
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 639, in map_aggregate
return fn(a)
^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/node.py", line 621, in <lambda>
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py", line 86, in <lambda>
return map_arg(a, lambda node: env[node.name])
~~~^^^^^^^^^^^
KeyError: 'rnn'
2. Define the states using self.register_buffer('states', states)
Code
import torch
import torch.nn as nn
from pulse_ai.tiny.helper_functions import (prepare, profile)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
class GRU(nn.Module):
    """Stateful wrapper around ``nn.GRUCell`` that keeps the hidden state in a buffer.

    The hidden state is registered as the buffer ``states`` and is reassigned
    by every call to ``forward``.

    NOTE(review): the traceback in this post shows ``symbolic_trace`` failing
    in ``forward`` with ``TypeError: cannot assign 'Proxy(add)' as buffer
    'states'``, i.e. the buffer assignment rejects the traced value.
    """

    def __init__(self, input_size, hidden_size=None):
        """Create the GRU cell and register a zero-filled hidden-state buffer.

        Args:
            input_size: Number of input features given to ``nn.GRUCell``.
            hidden_size: Width of the hidden state; falls back to
                ``input_size`` when ``None``.
        """
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        # Initialize empty states of shape (1, hidden_size) and register them
        # as a buffer so they are tracked by the module.
        states = torch.zeros(1, hidden_size)
        self.register_buffer('states', states)
        self.rnn = nn.GRUCell(input_size, hidden_size)

    def forward(self, x):
        """Advance the cell by one step and return the new hidden state.

        Args:
            x: Input tensor for the cell. The (1, hidden_size) initial state
                suggests a batch size of 1 -- TODO confirm.

        Returns:
            The new hidden state, which is also stored in the ``states`` buffer.
        """
        # Side effect: the registered buffer is replaced by the cell output.
        self.states = self.rnn(x, self.states)
        return self.states
# Reproduction script for approach 2 (state kept in a registered buffer).
# Build the wrapper with input_size=128; hidden_size falls back to input_size.
TestGRU = GRU(128)
dummy_input = torch.rand(1,128)
# Eager forward pass; GRU.forward also reassigns the 'states' buffer.
outputs = TestGRU(dummy_input)
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
# Per the traceback reported below, this call raises TypeError because
# forward() assigns a Proxy to the buffer 'states'; the statements after it
# are not reached in that run.
symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)
symbolic_traced.graph.print_tabular()
# Default qconfig mapping, requested with the "qnnpack" argument.
default_mapping = get_default_qconfig_mapping("qnnpack")
model_prepared = quantize_fx.prepare_fx(
symbolic_traced,
default_mapping,
dummy_input,
)
Error:
Traceback (most recent call last):
File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 30, in <module>
symbolic_traced : torch.fx.GraphModule = symbolic_trace(TestGRU)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
graph = tracer.trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/home/ahmed/PulseAudition_PFE/pulse_ai/scripts/tiny/Test.py", line 19, in forward
self.states = self.rnn(x, self.states)
^^^^^^^^^^^
File "/home/ahmed/anaconda3/envs/tiny_pulse_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1665, in __setattr__
raise TypeError("cannot assign '{}' as buffer '{}' "
TypeError: cannot assign 'Proxy(add)' as buffer 'states' (torch.Tensor or None expected)
In both approaches, I think the problem comes from reading and reassigning self.states inside the forward function.
Can you suggest any solution?
( Maybe try to externally update the states of the cell! )
Hoping to hear from you soon,
Ahmed