Is there an easy way to consistently traverse the optimiser states?
What I want to do is copy the optimiser states on the GPU into one big flat tensor on the CPU. To do that, I need an iterator that traverses the optimiser states in the same order every time.
My current solution is to use `optimizer.param_groups[0]['params']` as the traversal order and look up `optimizer.state[p]` for each parameter:
```python
import torch

params = optimizer.param_groups[0]['params']
numels = sum(p.numel() for p in params)

# Two Adam states (exp_avg, exp_avg_sq) per parameter. Pinned memory is
# needed for the non_blocking GPU->CPU copies below to actually be async.
flat_tensor = torch.empty((2 * numels,), pin_memory=True)

offset = 0
for p in params:
    for state_key in ['exp_avg', 'exp_avg_sq']:
        _target = optimizer.state[p][state_key]
        # View into the flat buffer (assumes contiguous state tensors, as Adam creates)
        _temp_view = flat_tensor.as_strided(_target.shape, _target.stride(), offset)
        _temp_view.copy_(_target, non_blocking=True)
        offset += _target.numel()
```
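Note that this order is already stable: `param_groups` is a plain Python list and its `'params'` entry preserves insertion order, so the loop above visits the states the same way on every call, as long as no parameters are added or reordered. If it helps, here is a minimal sketch of the same idea factored into reusable helpers and generalised to all param groups; the names `iter_opt_states` / `pack_states` and the hard-coded Adam state keys are my own assumptions, not a PyTorch API:

```python
import torch

def iter_opt_states(optimizer, state_keys=('exp_avg', 'exp_avg_sq')):
    """Yield optimizer state tensors in a fixed, repeatable order."""
    for group in optimizer.param_groups:      # groups keep their list order
        for p in group['params']:             # params keep insertion order
            for key in state_keys:            # keys in a fixed order
                yield optimizer.state[p][key]

def pack_states(optimizer, state_keys=('exp_avg', 'exp_avg_sq')):
    """Copy every optimizer state into one flat pinned CPU tensor."""
    tensors = list(iter_opt_states(optimizer, state_keys))
    flat = torch.empty(sum(t.numel() for t in tensors), pin_memory=True)
    offset = 0
    for t in tensors:
        flat[offset:offset + t.numel()].copy_(t.reshape(-1), non_blocking=True)
        offset += t.numel()
    return flat
```

Because the copies are non-blocking, you need a `torch.cuda.synchronize()` (or a stream sync) before reading the flat tensor on the CPU.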