A class referenced from another class fails to compile with TorchScript and export to ONNX

Hello!
I am trying to convert a PyTorch model to the ONNX format using TorchScript in Python. My code consists of multiple classes that reference and draw information from one another, and the problem is that I do not know how to connect all of them without errors so that the model can be exported. No matter what I try, I get stuck somewhere along the way. Here is a minimal reproducible example of what I am trying to do:

import numpy as np
import torch
from torch import autograd
from torch import nn
from torch import optim
import torch.nn.functional as F

import onnx
import onnxruntime
import torch.onnx


class BeamState:
    """Hold an ordered list of tensors collected via ``append``.

    A new state either starts empty or copies the list of an existing state,
    so the two can afterwards be extended independently.

    NOTE(review): this is a plain Python class. TorchScript fails to convert
    it when it is stored as an attribute of a ScriptModule (see the traceback
    in this question). Compiling it itself (e.g. ``@torch.jit.script`` with
    annotated attributes) is one possible route — confirm before relying on it.
    """

    def __init__(self, source=None):
        """Create an empty state, or a copy of ``source``'s list when given.

        Args:
            source: optional BeamState to copy from. The list is copied
                shallowly: the tensors themselves are shared, but appending
                to either state does not affect the other.
        """
        # Compare against the None sentinel explicitly; a truthiness test
        # would also treat any falsy argument as "no source".
        if source is None:
            self.mean_set = []
        else:
            self.mean_set = source.mean_set.copy()

    def append(self, mean, hidden, cluster):
        """Append a clone of ``mean`` to this state's list.

        The tensor is cloned, so later in-place changes to the caller's
        tensor do not alter the stored value.

        Args:
            mean: tensor to store (``.clone()`` is called on it).
            hidden: currently unused.
            cluster: currently unused.
        """
        self.mean_set.append(mean.clone())


    
class Pred (torch.jit.ScriptModule):
    """ScriptModule that multiplies its input by the result of ``self.bm(3)``.

    NOTE(review): as written, this class cannot be constructed. ``forward`` is
    compiled while ``Pred()`` runs, and compilation fails because the ``bm``
    attribute holds a plain Python ``BeamState`` object that TorchScript cannot
    convert. Independently of TorchScript, ``BeamState`` defines only
    ``__init__`` and ``append`` (no ``__call__``), so ``self.bm(3)`` would raise
    a TypeError even in eager mode. The intended meaning of ``self.bm(3)`` is
    not visible here and must be decided before this can be fixed — e.g. make
    the helper an ``nn.Module`` with a ``forward``, or a TorchScript class with
    annotated attributes (unverified suggestions).
    """
    def __init__(self):
        super(Pred, self).__init__()
        # Plain Python object stored as a module attribute; this is the
        # attribute TorchScript reports it cannot convert.
        self.bm = BeamState()
    @torch.jit.script_method
    def forward(self, x):
        """Return ``x`` multiplied element-wise by ``self.bm(3)``.

        NOTE(review): the type and shape of ``self.bm(3)`` are undefined
        in SOURCE — confirm they broadcast against ``x``.
        """
        beam_set = self.bm(3)
        prediction = x* beam_set
        return prediction


if __name__ == '__main__':
    # Example input: one row of 10 random features.
    batch_size = 1
    x = torch.randn(batch_size, 10)
    p_model = Pred()

    # Run the model once eagerly. The resulting tensor is exactly the output
    # the exporter needs as an example, so reuse it below instead of a
    # hard-coded literal whose dtype and values may not match the model.
    res = p_model(x)
    print("If you have reached this far, it works!", res)

    # NOTE(review): `example_outputs` is required when exporting a
    # ScriptModule in older torch 1.x releases, but it was deprecated and
    # later removed from torch.onnx.export — drop it on newer torch versions
    # (confirm against the installed version).
    torch.onnx.export(
        p_model,
        x,
        "onnx_test.onnx",
        do_constant_folding=False,
        export_params=True,
        input_names=['input'],
        output_names=['output'],
        example_outputs=res,
        dynamic_axes={'input': {0: 'batch_size', 1: 'utterance_size'}},
    )
    print("onnx model exported")

This produces the following error message:


Traceback (most recent call last):
  File "C:\Users\User\Python\Projects\onnx_tester.py", line 50, in <module>
    p_model = Pred()
  File "C:\Users\User\Python\lib\site-packages\torch\jit\_script.py", line 210, in init_then_script
    ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
  File "C:\Users\User\Python\lib\site-packages\torch\jit\_recursive.py", line 352, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "C:\Users\User\Python\lib\site-packages\torch\jit\_recursive.py", line 410, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "C:\Users\User\Python\lib\site-packages\torch\jit\_recursive.py", line 304, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
Module 'Pred' has no attribute 'bm' (This attribute exists on the Python module, but we failed to convert Python type: '__main__.BeamState' to a TorchScript type.):
  File "C:\Users\User\Python\Projects\onnx_tester.py", line 42
    @torch.jit.script_method
    def forward(self, x):
        beam_set = self.bm(3)
                   ~~~~~~~ <--- HERE
        prediction = x* beam_set
        return prediction

I have tried minor modifications, but every time I get a different error message, and I do not know where to begin. How do I go about converting an example like this to ONNX?