I have a short question about model constants and persistent state. Let’s say I have something like this:
```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, n=2):
        super().__init__()
        self.n = n
```
What’s the best way to make `n` part of the persistent state (i.e., the `state_dict`)? Should I make it a buffer? Then I would need to convert it into a tensor, which seems like a bit of a hassle. Is there a more elegant way?
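To make the problem concrete, here is a minimal sketch showing that a plain Python attribute set in `__init__` is not saved by `state_dict()`. The value `5` is just an arbitrary example.

```python
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, n=2):
        super().__init__()
        # Stored as an ordinary Python attribute, not as a parameter or buffer.
        self.n = n


m = MyModule(n=5)
print(m.n)                          # 5
print(list(m.state_dict().keys()))  # [] : n is not part of the state_dict
```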
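I think you’d better make it a buffer with `self.register_buffer('n', torch.tensor(n))`, in place of `self.n = n`. `register_buffer` only accepts a tensor (or `None`), so the int has to be wrapped, and it also refuses a name that already exists as a plain attribute. Buffers are included in `state_dict()`, as this excerpt of `nn.Module.state_dict` from an older PyTorch release shows: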
```python
def state_dict(self, destination=None, prefix=''):
    """Returns a dictionary containing a whole state of the module.
    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    Example:
        >>> module.state_dict().keys()
        ['bias', 'weight']
    """
    if destination is None:
        destination = OrderedDict()
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param.data
    for name, buf in self._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.')
    return destination
```
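A minimal sketch of the buffer approach applied to the module from the question: `n` is stored as a 0-d integer tensor, survives a `state_dict()` / `load_state_dict()` round trip, and is converted back to a Python int when needed. The values `5` and `2` are only for illustration.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, n=2):
        super().__init__()
        # register_buffer requires a Tensor (or None), so wrap the int in a
        # 0-d tensor; buffers are included in state_dict() by default.
        self.register_buffer("n", torch.tensor(n))


saved = MyModule(n=5).state_dict()
print(list(saved.keys()))  # ['n']

restored = MyModule()      # starts with the default n=2
restored.load_state_dict(saved)
print(int(restored.n))     # 5 : convert back to a Python int when needed
```

One side effect of this design: buffers follow `.to()` / `.cuda()` like parameters, so inside `forward` `self.n` is a tensor on the module's device, and `int(self.n)` or `self.n.item()` gives you the plain value back.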