Weights are not converging

After you’ve stored the state_dicts, you could iterate their keys and create a new state_dict using the mean (or any other reduction) of all parameters.
This code snippet shows a small example:

import collections
import copy

import torch
import torchvision.models as models

# Setup
state_dicts = []
model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1.)

# Train and store a copy of the state_dict after each step
for _ in range(5):
    optimizer.zero_grad()
    out = model(torch.randn(1, 3, 224, 224))
    out.mean().backward()
    optimizer.step()
    state_dicts.append(copy.deepcopy(model.state_dict()))

# Create a new state_dict holding the mean of all stored params
new_state_dict = collections.OrderedDict()

for key in model.state_dict():
    if 'num_batches_tracked' in key:  # integer counter, cannot be averaged; reuse the last value
        param = state_dicts[-1][key]
    else:
        param = torch.mean(torch.stack([sd[key] for sd in state_dicts]), dim=0)
    new_state_dict[key] = param

# Load the averaged parameters back into the model
model.load_state_dict(new_state_dict)
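As a sanity check, the same reduction can be verified on a tiny model without BatchNorm layers (a minimal sketch; the nn.Linear model here is just a stand-in for the snapshots you would store during training):

```python
import collections
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two differently initialized copies stand in for stored training snapshots
model = nn.Linear(4, 2)
snapshots = []
for _ in range(2):
    nn.init.normal_(model.weight)
    nn.init.normal_(model.bias)
    snapshots.append(copy.deepcopy(model.state_dict()))

# Average every parameter across the snapshots
avg_state_dict = collections.OrderedDict(
    (key, torch.mean(torch.stack([sd[key] for sd in snapshots]), dim=0))
    for key in model.state_dict()
)
model.load_state_dict(avg_state_dict)

# The loaded weight equals the elementwise mean of the two snapshots
expected = (snapshots[0]['weight'] + snapshots[1]['weight']) / 2
print(torch.allclose(model.weight, expected))  # True
```

If you want a running average maintained during training instead of averaging stored checkpoints afterwards, `torch.optim.swa_utils.AveragedModel` provides that out of the box.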