Initialization of register_buffer for loading model

Hello,

I am in a situation where I initialize and train a model on some datasets, then save the model's state dict. I would like to be able to load the model later without having to load the data again. However, because of a dimension mismatch when I call load_state_dict, I can't seem to do that.

Example:

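from typing import Optional

import numpy as np
import torch
from torch import nn

class Model(nn.Module):
    def __init__(
            self,
            layer_list,
            x_mean: Optional[np.ndarray] = None,
            x_std: Optional[np.ndarray] = None,
            **_
    ):
        super(Model, self).__init__()

        # Normalization statistics computed from the training data; None until known
        self.register_buffer('x_mean', None if x_mean is None else torch.FloatTensor(x_mean))
        self.register_buffer('x_std', None if x_std is None else torch.FloatTensor(x_std))
        # get_sequential is the poster's own helper (not shown here);
        # presumably it builds an nn.Sequential from layer_list
        self.sequential = get_sequential(layer_list)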

After training, x_mean, x_std, … are filled with tensors of specific shapes. However, when I load the model I "don't know" those shapes, and I would like to avoid looking them up in the data.

In other words, I would like to load the model using only layer_list and load_state_dict.

Thanks!

The register_buffer method accepts a persistent argument, which defines whether the buffer is added to the state_dict (it's True by default). If you don't want to store x_mean and x_std in the state_dict, you could set it to False for these buffers.
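A small sketch of the difference (the Normalizer class and its buffer values are illustrative, not from the original post): a buffer registered with persistent=False still lives on the module, but it is left out of the state_dict.

```python
import torch


class Normalizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Saved in the state_dict (persistent=True is the default)
        self.register_buffer("x_mean", torch.zeros(3))
        # Still a buffer on the module, but NOT saved in the state_dict
        self.register_buffer("x_std", torch.ones(3), persistent=False)


m = Normalizer()
print(list(m.state_dict().keys()))  # ['x_mean']
print(m.x_std)                      # tensor([1., 1., 1.])
```

The trade-off is that a non-persistent buffer has to be recomputed (or passed in again) every time the model is built, which is exactly what the question is trying to avoid.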

Actually, I would like to save them in the state dict. However, when I initialize the model (before loading the state dict), I do not know the shapes of those buffers. Therefore I get a shape mismatch error if I initialize them with a random tensor, and a different error if I initialize them with None. I would like to be able to load the model without explicitly specifying the shapes of the buffers, so that loading simply assigns whatever tensors were saved to those buffers.
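One possible workaround, sketched under the assumption that the buffers sit directly on the top-level module (no dotted names such as "encoder.x_mean"): before calling load_state_dict, re-register each buffer with an uninitialized tensor shaped like the saved one. load_buffers_from_state_dict is a hypothetical helper name, not a PyTorch API.

```python
import io
from typing import Iterable, Optional

import torch


class Model(torch.nn.Module):
    def __init__(self, x: Optional[list] = None):
        super().__init__()
        self.register_buffer("x", None if x is None else torch.FloatTensor(x))


def load_buffers_from_state_dict(model: torch.nn.Module, state_dict: dict,
                                 names: Iterable[str]) -> None:
    """Hypothetical helper: resize the named top-level buffers to match the
    saved tensors, then load the full state_dict as usual."""
    for name in names:
        if name in state_dict:
            # Re-registering an existing buffer name is allowed; this replaces
            # the None placeholder with a tensor of the saved shape and dtype.
            model.register_buffer(name, torch.empty_like(state_dict[name]))
    model.load_state_dict(state_dict)


# Round trip through an in-memory file instead of a path on disk
f = io.BytesIO()
torch.save(Model([1, 2]).state_dict(), f)
f.seek(0)
sd = torch.load(f)

model_loaded = Model()  # shape of x is unknown at construction time
load_buffers_from_state_dict(model_loaded, sd, ["x"])
print(model_loaded.x)  # tensor([1., 2.])
```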

Minimal working example:

import torch 
from typing import Optional

class Model(torch.nn.Module):
    def __init__(self,x: Optional[list] = None):
        super(Model, self).__init__()
        self.register_buffer('x', None if x is None else torch.FloatTensor(x))

PATH = 'state_dict.pt'
x = [1,  2]

model = Model(x)
torch.save(model.state_dict(), PATH)

model_loaded = Model()
model_loaded.load_state_dict(torch.load(PATH))
print(model_loaded.x)

In your code snippet you are loading the state_dict into model, while I assume you want to load it into model_loaded. That raises a RuntimeError, since model_loaded's x buffer was registered as None, so there is no tensor to load the saved x into:

import torch 
from typing import Optional

class Model(torch.nn.Module):
    def __init__(self,x: Optional[list] = None):
        super(Model, self).__init__()
        self.register_buffer('x', None if x is None else torch.FloatTensor(x))

PATH = 'state_dict.pt'
x = [1,  2]

model = Model(x)
sd = model.state_dict()

model_loaded = Model()
model_loaded.load_state_dict(sd)
> RuntimeError: Error(s) in loading state_dict for Model:
	Unexpected key(s) in state_dict: "x". 
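To see why "x" counts as unexpected, here is a quick check reusing the same Model definition: a buffer registered as None is skipped when the state_dict is built, so the empty model has no x entry at all. Passing strict=False does not help either; it just reports the key and silently leaves x as None.

```python
from typing import Optional

import torch


class Model(torch.nn.Module):
    def __init__(self, x: Optional[list] = None):
        super().__init__()
        self.register_buffer("x", None if x is None else torch.FloatTensor(x))


print(list(Model([1, 2]).state_dict().keys()))  # ['x']
print(list(Model().state_dict().keys()))        # []

# strict=False ignores the mismatch instead of loading the buffer
empty = Model()
result = empty.load_state_dict(Model([1, 2]).state_dict(), strict=False)
print(result.unexpected_keys)  # ['x']
print(empty.x)                 # None
```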

The proper way is to create this buffer before loading; since you already have the loaded state_dict, you know its shape:

model_loaded = Model(sd['x'])
model_loaded.load_state_dict(sd)
> <All keys matched successfully>
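The same idea carries over to the original Model with x_mean and x_std. A sketch, assuming a hypothetical from_state_dict classmethod and replacing the poster's get_sequential(layer_list) helper (not shown in the thread) with a single nn.Linear layer for illustration:

```python
from typing import Optional, Union

import numpy as np
import torch
from torch import nn

ArrayLike = Union[np.ndarray, torch.Tensor]


class Model(nn.Module):
    def __init__(self, in_features: int, x_mean: Optional[ArrayLike] = None,
                 x_std: Optional[ArrayLike] = None):
        super().__init__()
        # clone() so the buffers never share memory with the input arrays/tensors
        self.register_buffer("x_mean", None if x_mean is None
                             else torch.as_tensor(x_mean, dtype=torch.float32).clone())
        self.register_buffer("x_std", None if x_std is None
                             else torch.as_tensor(x_std, dtype=torch.float32).clone())
        # Stand-in for get_sequential(layer_list) from the original post
        self.linear = nn.Linear(in_features, 1)

    @classmethod
    def from_state_dict(cls, state_dict: dict, in_features: int) -> "Model":
        """Hypothetical factory: take the buffer shapes from the saved tensors."""
        model = cls(in_features,
                    x_mean=state_dict.get("x_mean"),
                    x_std=state_dict.get("x_std"))
        model.load_state_dict(state_dict)
        return model


trained = Model(3, x_mean=np.array([0.5, 1.0, 2.0]), x_std=np.array([1.0, 2.0, 4.0]))
sd = trained.state_dict()

restored = Model.from_state_dict(sd, in_features=3)
print(restored.x_std)  # tensor([1., 2., 4.])
```

Keeping the shape lookup inside a factory method means callers only need the layer configuration and the saved file, which is what the question asks for.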

Oh, OK.
This is exactly what I needed!
Thank you 🙂

(Yes, I made a typo in the minimal working example.)