Hello, can someone help? I am trying to implement a mechanism for resuming training. All my experiments reproduce exactly when I train from scratch, but now I need resumed training to be 100% reproducible as well, and I am running into problems. See the code below:
import torch

def compare_state_dicts(self, state_dict1, state_dict2, tol=1e-6):
    differences = {}
    # Check if both have the same keys
    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())
    # Check for any keys present in one state_dict but not the other
    missing_keys = keys1.symmetric_difference(keys2)
    if missing_keys:
        differences['missing_keys'] = list(missing_keys)
    # Compare each parameter/buffer present in both state dicts
    for key in keys1.intersection(keys2):
        param1 = state_dict1[key]
        param2 = state_dict2[key]
        # Check shape difference
        if param1.shape != param2.shape:
            differences[key] = {
                'type': 'shape_mismatch',
                'shape1': param1.shape,
                'shape2': param2.shape
            }
        else:
            # Check value difference within the tolerance
            max_diff = torch.max(torch.abs(param1 - param2)).item()
            if max_diff > tol:
                differences[key] = {
                    'type': 'value_mismatch',
                    'max_diff': max_diff
                }
    return differences
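For what it's worth, the helper behaves as intended in isolation. Here is a small, self-contained sanity check (the nn.Linear model and the 1e-3 perturbation are made up purely for illustration; I pass None for self because the helper never uses it when pasted as a plain function like above):

import torch
import torch.nn as nn

# Hypothetical sanity check, separate from the trainer.
model = nn.Linear(4, 3)
sd_ref = {k: v.clone() for k, v in model.state_dict().items()}

# Identical state dicts -> expect an empty dict.
print(compare_state_dicts(None, model.state_dict(), sd_ref))   # {}

# Perturb one parameter by 1e-3 -> expect a value_mismatch entry.
sd_ref['bias'] += 1e-3
print(compare_state_dicts(None, model.state_dict(), sd_ref))
# {'bias': {'type': 'value_mismatch', 'max_diff': ~0.001}}

And this is the train_step where I do the comparison when resuming: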
def train_step(self, batch, epoch: int):
    if self.input_mask:
        inputs, targets, input_lengths, target_lengths, mask = batch
        inputs, targets, input_lengths, target_lengths, mask = (
            inputs.to(self.device),
            targets.to(self.device),
            input_lengths.to(self.device),
            target_lengths.to(self.device),
            mask.to(self.device),
        )
    else:
        inputs, targets, input_lengths, target_lengths = batch
        inputs, targets, input_lengths, target_lengths = (
            inputs.to(self.device),
            targets.to(self.device),
            input_lengths.to(self.device),
            target_lengths.to(self.device),
        )

    # Debugging block: compare the live model against the checkpoint,
    # run a forward pass, reload the checkpoint weights, and run the
    # same forward pass again on the same batch.
    if epoch == 1:
        checkpoint = torch.load('/tmp/last.pt')
        dif = self.compare_state_dicts(self.model.state_dict(), checkpoint['model'])
        print(dif)
        outputs = self.model(inputs)
        print(outputs)

        # Reload the saved weights and repeat the exact same comparison
        # and forward pass.
        self.model.load_state_dict(checkpoint['model'])
        self.model.train()
        dif = self.compare_state_dicts(self.model.state_dict(), checkpoint['model'])
        print(dif)
        outputs = self.model(inputs)
        print(outputs)
        exit()
Both times dif is {}, but the two outputs differ slightly, which hurts reproducibility:
tensor([[[ 1.4755, -1.9476, -4.6788, ..., -3.0517, -2.3541, -1.6889],
tensor([[[ 1.4753, -1.9451, -4.6914, ..., -3.0670, -2.3551, -1.6928],
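In case it matters, the checkpoint at /tmp/last.pt is written by a helper along these lines (a sketch; only the 'model' key is confirmed by the code above, while the helper name and the 'optimizer' and 'epoch' entries are just my own description of the saving code):

import torch

def save_checkpoint(model, optimizer, epoch, path='/tmp/last.pt'):
    # Hypothetical helper mirroring my checkpoint saving; only the
    # 'model' entry is read in the debug block above, the rest are
    # assumptions for illustration.
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        },
        path,
    )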
What could be the problem here?