Loading TensorFlow checkpoints with PyTorch

I have some existing code that uses PyTorch to interact with the generator from a trained GAN.

The problem is that I have a new set of checkpoints I want to load, but they are from a TensorFlow implementation of the model. Instead of going through all the code and changing it to work with TF, I’d rather just load the model as if it were a Torch model.

The examples I found online were different from my setup, so I’m not sure how to go about this. Here’s what the folder with the TF checkpoints looks like:

Here’s how I’ve been loading other models with Torch:

```python
import torch

# WaveGANGenerator, model_prefix and the hyperparameters below are defined
# elsewhere in the project; the model parameters come from params.py
# (don't change these)

# model name
gan_model_name = "gan_{}.tar".format(model_prefix)

# load the saved checkpoint dict onto the CPU
checkpoint = torch.load(gan_model_name, map_location="cpu")

# initialize the generator with the same hyperparameters used in training
wavegan_generator = WaveGANGenerator(
    slice_len=window_length,
    model_size=model_capacity_size,
    use_batch_norm=use_batchnorm,
    num_channels=num_channels,
).to(device)

# load the saved weights into the initialized model
wavegan_generator.load_state_dict(checkpoint["generator"])
wavegan_generator.eval()
```

I know how to initialize the model with the correct parameters, but loading the TF checkpoints into the state_dict is where I’m stuck.

I’ve seen examples that go through ONNX, but they look different from what I have. Is there a way to load these TF checkpoints using the same Torch code?

Thanks

You could convert the TF checkpoint to a PyTorch state_dict manually, as @tom and I did for StyleGAN in this notebook.
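A minimal sketch of that manual conversion, assuming the files are a standard TF checkpoint readable with `tf.train.load_checkpoint`. The TF variable names for this particular WaveGAN implementation are not known here, so `name_map` and the layout labels (`"dense"`, `"conv1d"`, ...) are hypothetical and have to be filled in by comparing `tf.train.list_variables(...)` with `wavegan_generator.state_dict()`. The transposes reflect the usual layout difference: TF stores kernels with the spatial axes first and the channel axes last, while PyTorch puts the channel axes first.

```python
import numpy as np

# Axis permutations from TF kernel layouts to PyTorch weight layouts.
# The layout labels are hypothetical names used only in this sketch.
TF_TO_TORCH_PERM = {
    "dense": (1, 0),                   # (in, out)         -> Linear (out, in)
    "conv1d": (2, 1, 0),               # (w, in, out)      -> Conv1d (out, in, w)
    "conv1d_transpose": (2, 1, 0),     # (w, out, in)      -> ConvTranspose1d (in, out, w)
    "conv2d": (3, 2, 0, 1),            # (h, w, in, out)   -> Conv2d (out, in, h, w)
    "conv2d_transpose": (3, 2, 0, 1),  # (h, w, out, in)   -> ConvTranspose2d (in, out, h, w)
    # biases and batch-norm tensors (gamma -> weight, beta -> bias,
    # moving_mean -> running_mean, moving_variance -> running_var) are copied as-is
    "none": None,
}


def convert_tf_arrays(tf_arrays, name_map):
    """Build {torch_param_name: np.ndarray} from TF variables.

    tf_arrays: {tf_variable_name: np.ndarray}, e.g. from read_tf_checkpoint().
    name_map:  {torch_param_name: (tf_variable_name, layout_label)}; hypothetical,
               written by hand for the model. TF variables not listed here
               (optimizer slots, global_step, discriminator, ...) are ignored.
    """
    out = {}
    for torch_name, (tf_name, layout) in name_map.items():
        arr = np.asarray(tf_arrays[tf_name])
        perm = TF_TO_TORCH_PERM[layout]
        if perm is not None:
            arr = np.transpose(arr, perm)
        # contiguous copy so torch.from_numpy gets a plain C-ordered buffer
        out[torch_name] = np.ascontiguousarray(arr)
    return out


def read_tf_checkpoint(ckpt_dir_or_prefix):
    """Read every variable of a TF checkpoint into NumPy (needs tensorflow)."""
    import tensorflow as tf

    reader = tf.train.load_checkpoint(ckpt_dir_or_prefix)
    return {
        name: reader.get_tensor(name)
        for name, _shape in tf.train.list_variables(ckpt_dir_or_prefix)
    }


def to_state_dict(np_params):
    """Wrap the converted NumPy arrays as torch tensors (needs torch)."""
    import torch

    return {k: torch.from_numpy(v) for k, v in np_params.items()}
```

With a filled-in `name_map`, `wavegan_generator.load_state_dict(to_state_dict(convert_tf_arrays(read_tf_checkpoint(ckpt_path), name_map)))` would take the place of the `checkpoint["generator"]` line. Because `load_state_dict` is strict by default, any missing or unexpected key raises an error, and shape mismatches are reported too, which is a convenient way to find gaps or wrong layouts in the mapping.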
