After you’ve stored the state_dicts, you could iterate over their keys and create a new state_dict
using the mean (or any other reduction) of each parameter across the stored state_dicts.
This code snippet shows a small example:
# Setup: a list that collects one state_dict snapshot per training step,
# plus a model and an optimizer used to produce those snapshots.
state_dicts = []
model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1.)

# Train and store state_dict
for _ in range(5):
    optimizer.zero_grad()
    out = model(torch.randn(1, 3, 224, 224))
    out.mean().backward()
    optimizer.step()
    # Deep-copy the snapshot: state_dict() hands back references to the
    # model's own tensors, so without a copy every stored entry would
    # track the latest weights instead of the weights at this step.
    state_dicts.append(copy.deepcopy(model.state_dict()))

# Create new state_dict with mean of all params
new_state_dict = collections.OrderedDict()
for key in model.state_dict():
    snapshots = [sd[key] for sd in state_dicts]
    latest = snapshots[-1]
    if latest.is_floating_point() or latest.is_complex():
        # Element-wise mean across the stored snapshots.
        param = torch.stack(snapshots).mean(dim=0)
    else:
        # torch.mean is not defined for integer/bool tensors (for example
        # the 'num_batches_tracked' counters), so handle these separately
        # and reuse the last stored value.
        param = latest
    new_state_dict[key] = param

# Load into model
model.load_state_dict(new_state_dict)