Convert a ProGAN agent from .pth to .onnx

I trained a ProGAN agent using this PyTorch reimplementation and saved it as a .pth file. Now I need to convert the agent to the .onnx format, which I am trying to do with this script:

from torch.autograd import Variable  # unused in this script

import torch.onnx
import torchvision  # unused in this script
import torch

device = torch.device("cuda")  # NOTE(review): dummy_input below is created on CPU; it should live on the model's device

dummy_input = torch.randn(1, 3, 64, 64)  # NOTE(review): must match the model's forward() signature; a ProGAN generator presumably takes a latent vector, not an image - confirm
state_dict = torch.load("GAN_agent.pth", map_location = device)  # yields a collections.OrderedDict of saved tensors, not an nn.Module

torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")  # BUG: export needs an nn.Module; build the model, call model.load_state_dict(state_dict), then export the model

When I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I understand, the problem is that converting the agent to .onnx requires more information. Am I missing something?

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
     10 state_dict = torch.load("GAN_agent.pth", map_location = device)
     11 
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    146                         operator_export_type, opset_version, _retain_param_name,
    147                         do_constant_folding, example_outputs,
--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    149 
    150 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
     64             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
     65             example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
     67 
     68 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
    414                                                         example_outputs, propagate,
    415                                                         _retain_param_name, do_constant_folding,
--> 416                                                         fixed_batch_size=fixed_batch_size)
    417 
    418         # TODO: Don't allocate a in-memory string for the protobuf

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
    277             model.graph, tuple(in_vars), False, propagate)
    278     else:
--> 279         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
    280         state_dict = _unique_state_dict(model)
    281         params = list(state_dict.values())

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
    226     # A basic sanity check: make sure the state_dict keys are the same
    227     # before and after running the model.  Fail fast!
--> 228     orig_state_dict_keys = _unique_state_dict(model).keys()
    229 
    230     # By default, training=False, which is good because running a model in

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
    283     # id(v) doesn't work with it. So we always get the Parameter or Buffer
    284     # as values, and deduplicate the params using Parameters and Buffers
--> 285     state_dict = module.state_dict(keep_vars=True)
    286     filtered_dict = type(state_dict)()
    287     seen_ids = set()

AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

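torch.onnx.export expects the model (an nn.Module) as its first argument, but it seems you are passing the state_dict to it. A state_dict is only an OrderedDict of tensors with no forward method to trace, which is why the exporter fails when it calls module.state_dict() on it.
Could you create an instance of the model, load the state_dict via model.load_state_dict(state_dict), and then try to export the model?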

Done! Problem solved, more details here. Thanks for the help.