How to load state_dict if saved as torch.save(model.state_dict(), fs)

Did you try to serialize the state_dict of the original model as described here? CC @marksaroufim in case this workaround is not needed anymore.
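A minimal sketch of both sides of this, under some assumptions. The thread does not say how the original checkpoint was produced, so this example assumes it came from a wrapped model whose `state_dict` keys carry an `_orig_mod.` prefix (the prefix `torch.compile` is assumed to add). `make_model` and `strip_prefix` are hypothetical helpers, and `io.BytesIO` stands in for `fs`. The preferred route saves the unwrapped model's `state_dict` directly. The fallback rebuilds the same architecture and renames the keys before calling `load_state_dict`, which is strict by default and raises on missing or unexpected keys.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical architecture; it must match the one that produced the checkpoint.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def strip_prefix(state_dict, prefix="_orig_mod."):
    # Hypothetical helper: drop a wrapper prefix (assumed to be the one
    # torch.compile adds) so the keys match an unwrapped model.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Preferred: serialize the state_dict of the original (unwrapped) model.
model = make_model()
fs = io.BytesIO()  # stands in for the `fs` file object or path from the question
torch.save(model.state_dict(), fs)

# Simulate a checkpoint saved from a wrapped model whose keys carry the prefix.
wrapped_fs = io.BytesIO()
torch.save({"_orig_mod." + k: v for k, v in model.state_dict().items()}, wrapped_fs)

# Loading: rebuild the architecture first, then copy the saved tensors into it.
fs.seek(0)
restored = make_model()
restored.load_state_dict(torch.load(fs, map_location="cpu"))

# Fallback for the prefixed checkpoint: rename keys, then load strictly.
wrapped_fs.seek(0)
restored_from_wrapped = make_model()
restored_from_wrapped.load_state_dict(
    strip_prefix(torch.load(wrapped_fs, map_location="cpu"))
)

# Both restored models should reproduce the original model's outputs.
x = torch.randn(3, 4)
with torch.no_grad():
    ok_plain = torch.allclose(model(x), restored(x))
    ok_wrapped = torch.allclose(model(x), restored_from_wrapped(x))
print(ok_plain, ok_wrapped)
```

Saving the unwrapped module's `state_dict` avoids the renaming step entirely. The prefix-stripping fallback is only needed for checkpoints that were already written from the wrapper.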