Indeed, it is easy enough to change the keys manually.
In my case, I did it like this:
# Rename keys in place, e.g. 'conv.1.weight' -> 'conv1.weight'
for key in list(state_dict.keys()):
    state_dict[key.replace('.1.', '1.').replace('.2.', '2.')] = state_dict.pop(key)
It's not generalizable, but if no one else has encountered this problem, it probably isn't a big deal.
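The loop above only handles the indices 1 and 2. Here is a slightly more general sketch. It assumes the mismatch is always an extra dot before a numeric index (as in `conv.1.weight` vs `conv1.weight`), uses a regular expression, and checks for name collisions before modifying the dict. `rename_keys_` and `merge_index` are hypothetical helper names, not part of PyTorch.

```python
import re
from typing import Callable, MutableMapping

# Hypothetical rule: drop the dot before a numeric index that is followed by
# another dot, e.g. 'conv.1.weight' -> 'conv1.weight', 'layer.10.bias' -> 'layer10.bias'.
_INDEX_BEFORE_DOT = re.compile(r"\.(\d+)(?=\.)")


def merge_index(key: str) -> str:
    return _INDEX_BEFORE_DOT.sub(r"\1", key)


def rename_keys_(state_dict: MutableMapping[str, object],
                 rename: Callable[[str], str]) -> None:
    """Rename every key in place, keeping the original key order."""
    old_keys = list(state_dict.keys())
    new_keys = [rename(k) for k in old_keys]
    # Refuse to run if two old keys would end up with the same new name.
    if len(set(new_keys)) != len(new_keys):
        raise ValueError("rename maps two keys to the same name")
    # Pop everything first, then re-insert, so a new name can never
    # overwrite an old key that has not been processed yet.
    values = [state_dict.pop(k) for k in old_keys]
    for new_key, value in zip(new_keys, values):
        state_dict[new_key] = value
```

With PyTorch you would load the checkpoint with `torch.load`, call `rename_keys_(state_dict, merge_index)`, and then pass the dict to `model.load_state_dict`. Because the dict is modified in place, any extra attributes on the loaded object are kept. The pattern is only an example. It also rewrites keys such as `features.0.weight`, so adjust it to the actual difference between your checkpoint and your model.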