ParameterDict breaks torch.compile

As the title says, using an nn.ParameterDict inside an nn.Module breaks torch.compile, as shown below:

import torch
import torch.nn as nn

class Model(nn.Module):
    """Minimal reproducer: indexing an ``nn.ParameterDict`` in ``forward`` fails under ``torch.compile``.

    Holds a single parameter of shape (1, 1, 128), initialised to zeros and
    stored in a ``ParameterDict`` under the key ``"foo"``. The same object is
    also bound directly as ``self.parameter`` so the working variant of
    ``forward`` can be compared against the failing one.
    """

    def __init__(self):
        super().__init__()

        self.parameter_dict = nn.ParameterDict({"foo": nn.Parameter(torch.zeros(1, 1, 128))})
        # Alias to the very same Parameter object stored under "foo" (no copy),
        # so the commented-out workaround in forward() reads identical data
        # without going through ParameterDict.__getitem__.
        self.parameter = self.parameter_dict["foo"]

    def forward(self, x):
        """Return the stored parameter added to ``x``.

        NOTE(review): assumes ``x`` broadcasts against (1, 1, 128); the repro
        below passes ``torch.randn(1, 1, 128)``.
        """
        # Subscripting the ParameterDict is what Dynamo rejects when tracing
        # ("AssertionError: ParameterDict" in the stack trace below).
        return self.parameter_dict["foo"] + x # this breaks
        # Workaround, reported to compile fine: read the aliased attribute instead.
        # return self.parameter + x  #this works fine


# Build the module and wrap it with torch.compile. The report's reasoning is
# that a failure even with backend="eager" points at Dynamo's graph capture
# rather than at a compiler backend.
model = Model()
model = torch.compile(model, backend="eager")
x = torch.randn(1, 1, 128)

# The first call triggers Dynamo tracing of Model.forward. Per the stack trace
# below, this raises an AssertionError from torch/_dynamo/variables/nn_module.py.
model(x)

Since it breaks even with backend="eager", it’s probably a torch._dynamo issue. I don’t know whether this is expected behaviour. Here is the stack trace:

  File "test_compile.py", line 21, in <module>
    model(x)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 148, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 566, in call_function
    result = handler(tx, *args, **kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 790, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
  File "/mnt/home/.pyenv/versions/geo/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 402, in call_method
    assert type(module).__getitem__ in (
AssertionError: ParameterDict

from user code:
   File "test_compile.py", line 13, in forward
    return self.parameter_dict["foo"] + x

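This is probably just a standard Dynamo coverage gap. The assertion in torch/_dynamo/variables/nn_module.py only accepts a fixed set of `__getitem__` implementations, and ParameterDict’s is not among them. It's worth cross-posting to GitHub issues and tagging voz.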

1 Like