Torch distributions and yaml

I’m trying to use a YAML file to configure which distributions my project uses. To minimize the amount of parsing code I have to write, I use PyYAML tags. My document looks like this:

test: !!python/object:torch.distributions.Normal
  loc: 42.0
  scale: 1.0

Loading this document works:

import yaml

# Close the file after loading; yaml.Loader is needed for !!python tags
with open("test.yaml") as f:
    test = yaml.load(f, Loader=yaml.Loader)

However, I can’t use the object:

test
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/myuser/conda/envs/test/lib/python3.8/site-packages/torch/distributions/distribution.py", line 334, in __repr__
    [
  File "/home/myuser/conda/envs/test/lib/python3.8/site-packages/torch/distributions/distribution.py", line 335, in <listcomp>
    f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
AttributeError: 'float' object has no attribute 'numel'

It looks like the object is not properly initialized: loc and scale are stored as plain Python floats rather than tensors.
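The traceback fits that guess. Below is a stdlib-only sketch that mimics the two PyYAML tags. It assumes that !!python/object: allocates the instance with cls.__new__ and copies the YAML mapping into its __dict__ without calling __init__ (PyYAML's behavior for classes that define no __setstate__), while !!python/object/apply: calls cls(*args, **kwds). FakeNormal, load_as_object and load_as_apply are hypothetical names used only for illustration, and the tuple in FakeNormal.__init__ merely stands in for the tensor conversion the real Normal performs.

```python
class FakeNormal:
    """Hypothetical stand-in for torch.distributions.Normal.

    Like the real class, __init__ converts its arguments and sets extra
    state, so skipping __init__ leaves the object half-built.
    """

    def __init__(self, loc, scale):
        # The real Normal wraps numbers in tensors; a tuple stands in for that.
        self.loc = (float(loc),)
        self.scale = (float(scale),)
        self.batch_shape = ()  # state that only __init__ sets


def load_as_object(cls, state):
    """Mimic !!python/object:cls: allocate with __new__, then copy the
    YAML mapping into __dict__. __init__ is never called."""
    instance = cls.__new__(cls)
    instance.__dict__.update(state)
    return instance


def load_as_apply(cls, args=(), kwds=None):
    """Mimic !!python/object/apply:cls: call the class, so __init__ runs."""
    return cls(*args, **(kwds or {}))


broken = load_as_object(FakeNormal, {"loc": 42.0, "scale": 1.0})
fixed = load_as_apply(FakeNormal, [42.0, 1.0])

print(type(broken.loc).__name__, hasattr(broken, "batch_shape"))  # float False
print(type(fixed.loc).__name__, hasattr(fixed, "batch_shape"))    # tuple True
```

Because the real Normal only turns loc and scale into tensors inside __init__, the __new__ path leaves them as floats, which is what the .numel() call in __repr__ trips over.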

Any idea what’s going wrong?

I found a solution using the following YAML file:

test: !!python/object/apply:torch.distributions.Normal
  - 42.0
  - 1.0

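Note the /apply suffix on the tag, which makes PyYAML call the class (so __init__ runs and converts the floats to tensors), and the list syntax used to pass positional arguments.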
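If you prefer named parameters, the same tag also accepts a mapping with args and kwds keys instead of a plain list. A minimal sketch, assuming PyYAML and PyTorch are installed; the YAML is inlined as a string only to keep the example self-contained, and yaml.Loader is used as in the question (it can construct arbitrary Python objects, so only use it on trusted files).

```python
import torch
import yaml

# Keyword-argument form of !!python/object/apply: PyYAML reads the "kwds"
# mapping and calls torch.distributions.Normal(loc=42.0, scale=1.0).
CONFIG = """
test: !!python/object/apply:torch.distributions.Normal
  kwds:
    loc: 42.0
    scale: 1.0
"""

config = yaml.load(CONFIG, Loader=yaml.Loader)
dist = config["test"]

# __init__ ran, so loc and scale are tensors and repr() works
print(dist)
```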