AttributeError: 'numpy.ndarray' object has no attribute 'dim' from torch/nn/functional.py

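Is this error raised directly after the print statement? And did you check the types of all state inputs?
`dim()` is a tensor method, so this error usually means a numpy array was passed to a `torch.nn.functional` call. Could some of your inputs be numpy arrays somehow?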
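A quick way to check is to inspect every input right before the forward call. Below is a minimal sketch. It assumes the inputs are collected in a dict. The names `find_numpy_inputs`, `state_inputs`, `"state"` and `"reward"` are hypothetical and only for illustration:

```python
import numpy as np


def find_numpy_inputs(inputs):
    """Return the names of inputs that are NumPy arrays.

    `inputs` maps a (hypothetical) input name to the object you are about
    to pass into the model. Only NumPy arrays are flagged here; other
    non-tensor types are not checked.
    """
    return [name for name, value in inputs.items() if isinstance(value, np.ndarray)]


# Hypothetical inputs; in a real script these would come from your env or data pipeline.
state_inputs = {
    "state": np.zeros((1, 4), dtype=np.float32),  # still a NumPy array: would trigger the error
    "reward": 1.0,
}
print(find_numpy_inputs(state_inputs))  # ['state']
```

Any flagged input can be converted with `torch.from_numpy(arr)`, which shares memory with the array, or with `torch.as_tensor(arr)`. Keep in mind that NumPy defaults to `float64`, while most models expect `float32`. Calling `.float()` on the resulting tensor avoids a dtype mismatch.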