How to recreate a single image with a trained model?

Hello!

I have trained a CCGAN model and saved the generator and discriminator using

torch.save(gen.state_dict(), "GENERATOR/gen.pt")
torch.save(disc.state_dict(), "DISCRIMINATOR/disc.pt")

I now wish to test this model on a single image. (I have trained several models on slightly different custom datasets, and I want to see which dataset fits best.) How do I go about feeding the model a single image to test its output? Is there a tutorial on this? Any tips are welcome.

Thank you!

From what I understand, you want to pass a single image through the model.

Models usually expect inputs to the forward method in the shape BxCxHxW (Batch x Channels x Height x Width).

So if you have a single image of shape CxHxW, you need to add the batch dimension with unsqueeze(0):

output = model(img.unsqueeze(0))
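For instance, a minimal shape check with a dummy tensor (the random image here is just an illustration; your real image tensor would take its place):

import torch

img = torch.rand(3, 256, 256)   # a single image of shape CxHxW (dummy values)
batch = img.unsqueeze(0)        # add the batch dimension -> 1xCxHxW
print(batch.shape)              # torch.Size([1, 3, 256, 256])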

Hope this helps :slight_smile:


I guess I’m supposed to run the image through just the generator part, but I’m unsure how to properly load the generator model and the saved state dict.

I’ll give it a try! Thanks for your help :smiley:

To save and load you can just use this.

[screenshot of the save/load snippet]
Taken from here
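In case the screenshot doesn’t load, the recommended state_dict pattern looks roughly like this (a sketch; the Generator import and constructor arguments stand in for your own class):

import torch
from generator import Generator   # your own Generator class

# rebuild the architecture with the same constructor arguments you used for training
gen = Generator(features_g=64, num_channels=3)   # example arguments; use your own

# save only the learned parameters (the state_dict), not the whole object
torch.save(gen.state_dict(), "GENERATOR/gen.pt")

# later, to load: create a fresh model and restore the weights into it
gen = Generator(features_g=64, num_channels=3)
gen.load_state_dict(torch.load("GENERATOR/gen.pt"))
gen.eval()   # switch BatchNorm/Dropout to inference behavior before generating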

Hmm, using

import torch
import torch.nn as nn
import torch.optim as optim
import tifffile as tiff
from generator import Generator

model = Generator(features_g = 64, num_channels = 3)

model.load_state_dict(torch.load('gen.pt'))

img = tiff.imread('test.tiff')
output = model(img.unsqueeze(0))

I get

RuntimeError: Error(s) in loading state_dict for Generator:
	Unexpected key(s) in state_dict: "down1.model.0.weight", 
"down2.model.0.weight", "down2.model.1.weight", "down2.model.1.bias", "down2.model.1.running_mean", "down2.model.1.running_var",...

Any idea?

Can you print the state_dict your model actually has and the one that is being loaded, to see what is expected?

for param in torch.load('gen.pt'):
    print(param)
for param in model.state_dict():
    print(param)
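If the key lists are long, a quick way to spot the mismatch (just a suggestion) is to diff the two key sets:

ckpt_keys = set(torch.load('gen.pt').keys())
model_keys = set(model.state_dict().keys())
print(ckpt_keys - model_keys)   # keys in the checkpoint but missing from the model
print(model_keys - ckpt_keys)   # keys in the model but missing from the checkpoint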

Also, make sure the image is converted to a tensor before passing it to the model.
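Something along these lines should work, assuming tifffile gives you an HxWxC uint8 array and that you scaled inputs to [0, 1] during training (adjust the normalization to whatever your training pipeline used):

import torch
import tifffile as tiff

img = tiff.imread('test.tiff')            # NumPy array, typically HxWxC for an RGB image
img = torch.from_numpy(img).float()       # convert to a float tensor
img = img.permute(2, 0, 1) / 255.0        # HxWxC -> CxHxW, scale to [0, 1] (assumption)

model.eval()
with torch.no_grad():
    output = model(img.unsqueeze(0))      # add the batch dimension -> 1xCxHxW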


Will keep in mind to put the image in a tensor.

for param in torch.load('gen.pt'):
    print(param)

prints out

down1.model.0.weight
down2.model.0.weight
down2.model.1.weight
down2.model.1.bias
down2.model.1.running_mean
down2.model.1.running_var
down2.model.1.num_batches_tracked
down3.model.0.weight
down3.model.1.weight
down3.model.1.bias
down3.model.1.running_mean
down3.model.1.running_var
down3.model.1.num_batches_tracked
down4.model.0.weight
down4.model.1.weight
down4.model.1.bias
down4.model.1.running_mean
down4.model.1.running_var
down4.model.1.num_batches_tracked
down5.model.0.weight
down5.model.1.weight
down5.model.1.bias
...

These are the same keys as in the error.

Could you also post the output for model.state_dict()?

Huh, there’s nothing in the print. I don’t understand…

The reason it was empty is that I didn’t call self.build() in the model’s __init__, so no layers were ever registered.
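For anyone who hits the same thing: nn.Module only registers layers that are assigned as attributes, so if they are created in a helper like build() that is never called, state_dict() comes back empty. A minimal sketch (the class and layer here are made up to show the mechanism, not the actual CCGAN generator):

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, features_g=64, num_channels=3):
        super().__init__()
        self.features_g = features_g
        self.num_channels = num_channels
        self.build()   # without this call nothing is registered and state_dict() is empty

    def build(self):
        # layers only become part of the module once they are assigned as attributes
        self.down1 = nn.Conv2d(self.num_channels, self.features_g,
                               kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.down1(x)

print(Generator().state_dict().keys())   # contains down1.weight and down1.bias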

Do the parameters now match, or are they still different?

Everything seems in order now. Hopefully the generator works properly. Thanks :smiley:
