Add constants to persistent state


#1

I have a short question regarding model constants and persistent state. Let’s say I have something like

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, n=2):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.n = n

What’s the best way to make n part of the persistent state (i.e. the state_dict)? Should I make it a buffer? But then I would need to convert it into a tensor, which seems a bit of a hassle. Is there another more elegant way?


(Yun Chen) #2

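I think you’d better make it a buffer with self.register_buffer('n', torch.tensor(n)). register_buffer only accepts a Tensor (or None), so passing the plain int raises a TypeError; the conversion cannot be skipped with this approach. Buffers are saved because Module.state_dict collects only parameters, buffers and the state of submodules. This is the relevant source (PyTorch at the time of writing):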

    def state_dict(self, destination=None, prefix=''):
        """Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.

        Example:
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.')
        return destination
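A minimal sketch of the buffer approach end to end, saving the state dict from one instance and loading it into another. For comparison it also shows the get_extra_state/set_extra_state hooks that nn.Module gained in PyTorch 1.10; those postdate this thread, so that half assumes a recent PyTorch version. MyModuleExtra is a hypothetical class name used only for illustration.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    """Option 1: store the constant as a 0-dim tensor buffer."""

    def __init__(self, n=2):
        super().__init__()
        # register_buffer requires a Tensor (or None), so the int is wrapped.
        self.register_buffer("n", torch.tensor(n))


class MyModuleExtra(nn.Module):
    """Option 2 (assumes PyTorch >= 1.10): keep n as a plain int, save it as extra state."""

    def __init__(self, n=2):
        super().__init__()
        self.n = n

    def get_extra_state(self):
        # The returned object is stored in the state dict under '_extra_state'.
        return {"n": self.n}

    def set_extra_state(self, state):
        # load_state_dict passes back the object saved by get_extra_state.
        self.n = state["n"]


a = MyModule(n=5)
b = MyModule()  # starts with the default n=2
b.load_state_dict(a.state_dict())
print(list(a.state_dict().keys()))  # ['n']
print(int(b.n))                     # 5

c = MyModuleExtra(n=5)
d = MyModuleExtra()  # starts with the default n=2
d.load_state_dict(c.state_dict())
print(list(c.state_dict().keys()))  # ['_extra_state']
print(d.n)                          # 5
```

With the buffer, n is a tensor that also follows .to(device); use int(self.n) wherever a Python int is needed. The extra-state route keeps n as an int, but only on PyTorch versions that call these hooks.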