Save dictionary into state_dict

So, I wrote a custom extension of BatchNorm that keeps track of multiple running_mean and running_var tensors in a dictionary. The number of these means and variances is not known beforehand (it could be 3 or 100), so I can't replace them with a fixed number of parameters. The dictionary approach works well for me, except that I cannot register the running_means and running_vars dictionaries in the state dict for later reloading. Can someone suggest a workaround?

You can use self.register_buffer(param_name, initial_value) to add these tensors to the model's state dict.

This snippet shows how to do it:

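import torch
import torch.nn as nn

class MultiStatBN(nn.Module):  # class name is illustrative
    def __init__(self):
        super().__init__()  # must run before register_buffer
        self.running_means = {}
        for i in range(100):
            value = torch.randn(5)
            self.running_means[f'running_mean{i}'] = value
            self.register_buffer(f'running_mean{i}', value)  # same tensor object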
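A fuller sketch of the same idea, covering both means and variances and a save/reload round trip. The class name MultiStatBatchNorm1d, its num_features/num_stats arguments, and the zeros/ones initialisation are assumptions for illustration, not part of the original post; there is no forward pass, only the bookkeeping.

```python
import torch
import torch.nn as nn


class MultiStatBatchNorm1d(nn.Module):
    """Hypothetical BatchNorm variant keeping `num_stats` sets of running stats."""

    def __init__(self, num_features, num_stats):
        super().__init__()
        self.running_means = {}
        self.running_vars = {}
        for i in range(num_stats):
            mean = torch.zeros(num_features)
            var = torch.ones(num_features)
            # Register the *same* tensor objects that the dicts hold.
            self.register_buffer(f"running_mean{i}", mean)
            self.register_buffer(f"running_var{i}", var)
            self.running_means[i] = mean
            self.running_vars[i] = var


bn = MultiStatBatchNorm1d(num_features=5, num_stats=3)
with torch.no_grad():
    bn.running_means[2].fill_(7.0)  # in-place update, visible through the buffer

state = bn.state_dict()
print(sorted(state.keys()))  # running_mean0..2 and running_var0..2

# Reload into a fresh module built with the same num_stats.
bn2 = MultiStatBatchNorm1d(num_features=5, num_stats=3)
bn2.load_state_dict(state)  # copies into the existing buffers in place
print(bn2.running_means[2])  # all 7s, because the dict entry is the buffer
```

Because load_state_dict is strict by default, the fresh module has to be constructed with the same number of stat sets before loading, so that number would need to be saved alongside the state dict.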

I tried this, but the running_mean{i} buffer does not get updated when I update self.running_means[i]. I even registered self.running_means[i] itself as the buffer, to make sure the buffer points to the dictionary element, but that didn't help either.

If you update them by reassigning the value, like self.running_means[i] = torch.tensor(...), the buffers won't get updated. If you change these tensors in place instead, e.g. self.running_means[i].add_(...), it should work.
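A minimal sketch of the difference, using a bare nn.Module and a single buffer named running_mean0 (the dict name stats is illustrative):

```python
import torch
import torch.nn as nn

m = nn.Module()
t = torch.zeros(3)
stats = {0: t}
m.register_buffer("running_mean0", t)  # dict and buffer share one tensor

# Rebinding the dict key: the dict now holds a new tensor, the buffer keeps the old one.
stats[0] = torch.tensor([1.0, 2.0, 3.0])
after_rebind = m.state_dict()["running_mean0"].clone()
print(after_rebind)  # still zeros

# Restore the shared reference, then update in place instead.
stats[0] = m.running_mean0
stats[0].add_(torch.tensor([1.0, 2.0, 3.0]))
after_inplace = m.state_dict()["running_mean0"].clone()
print(after_inplace)  # now 1, 2, 3
```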

So, I should do:

self.running_means[i].add_(-self.running_means[i] + new_val)

to set self.running_means[i] = new_val?
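A small sketch comparing the add_ trick with Tensor.copy_, which overwrites the values in place directly. The names running_means and new_val follow the thread; the torch.no_grad() wrapper is an assumption about how the update would be called, and matters when new_val is computed from activations that require grad, since an in-place op would otherwise pull the buffer into the autograd graph.

```python
import torch
import torch.nn as nn

m = nn.Module()
running_means = {0: torch.zeros(4)}
m.register_buffer("running_mean0", running_means[0])

new_val = torch.tensor([0.1, 0.2, 0.3, 0.4])
with torch.no_grad():
    # The trick from the post: x += (-x + new_val), applied in place.
    running_means[0].add_(-running_means[0] + new_val)
ok_add = torch.allclose(m.running_mean0, new_val)  # allclose: extra arithmetic can round
print(ok_add)

newer_val = torch.tensor([1.0, 2.0, 3.0, 4.0])
with torch.no_grad():
    # copy_ overwrites the same tensor directly, without the intermediate sum.
    running_means[0].copy_(newer_val)
print(torch.equal(m.running_mean0, newer_val))
```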

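To set self.running_means[i] to new_val without breaking the link, you can assign to its data: self.running_means[i].data = new_val. This keeps the same tensor object (the one registered as the buffer) and only swaps the data it holds. self.running_means[i].copy_(new_val) does the same with a true in-place copy.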

If you reassign the value with self.running_means[i] = new_val, the dict key points to the new tensor, which is not the one registered as the buffer (the buffer still holds the old value).
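An alternative that avoids keeping two references in sync: store only the buffer names in the dict and look the tensors up with getattr/setattr. This is a sketch, not the approach from the thread; the class MultiStat and its helper methods are hypothetical. It also sidesteps a related pitfall: moving a module to another device (e.g. .cuda()) replaces its buffer tensors, which would leave a dict of tensor references pointing at the old copies.

```python
import torch
import torch.nn as nn


class MultiStat(nn.Module):
    """Hypothetical module: the dict stores buffer *names*, not tensors."""

    def __init__(self, num_stats, num_features):
        super().__init__()
        self.mean_names = {}
        for i in range(num_stats):
            name = f"running_mean{i}"
            self.register_buffer(name, torch.zeros(num_features))
            self.mean_names[i] = name

    def running_mean(self, i):
        # Always fetch the live buffer, so there is no second reference to go stale.
        return getattr(self, self.mean_names[i])

    def set_running_mean(self, i, new_val):
        # Assigning to a registered buffer name replaces the buffer itself.
        setattr(self, self.mean_names[i], new_val)


m = MultiStat(num_stats=3, num_features=2)
m.set_running_mean(1, torch.tensor([5.0, 6.0]))
print(m.state_dict()["running_mean1"])  # the new values are in the state dict
```

With this layout, plain reassignment is safe because nothing else holds the old tensor.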

Worked! Thanks a lot.