Accessing weights in a pretrained model for handwritten text recognition

I am trying to access the weights in a model checkpoint. When I load it, I find these keys:

dict_keys(['optimizer_state_dict', 'epoch', 'scaler_state_dict', 'best', 'curriculum_config', 'encoder_state_dict', 'decoder_state_dict'])

This is a screenshot of the output I have:

I would really appreciate your help.
Thank you in advance! :blush:

It seems you are loading a custom dict that contains several state_dicts. To check the weights (or any other parameter or buffer) of the model(s), you could use:

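import torch

checkpoint = torch.load(...)  # pass the path to your checkpoint file
encoder_state_dict = checkpoint['encoder_state_dict']
# parameter_name must be one of the keys in encoder_state_dict.keys()
print(encoder_state_dict[parameter_name])
# same for decoder_state_dict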
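
If you don't know the parameter names yet, a minimal sketch like the one below lists every entry with its shape. The checkpoint here is a small stand-in built from `nn.Conv2d` and `nn.Linear` so the snippet runs on its own, and `summarize_state_dict` is a hypothetical helper name; with your file you would use the dict returned by `torch.load` instead.

```python
import torch
import torch.nn as nn

# Stand-in checkpoint so the snippet runs on its own (assumption: your real
# checkpoint stores ordinary state_dicts under the same two keys).
checkpoint = {
    "encoder_state_dict": nn.Conv2d(1, 4, kernel_size=3).state_dict(),
    "decoder_state_dict": nn.Linear(4, 10).state_dict(),
}


def summarize_state_dict(state_dict):
    """Return (name, shape) pairs for every tensor in a state_dict."""
    return [(name, tuple(tensor.shape)) for name, tensor in state_dict.items()]


for key in ("encoder_state_dict", "decoder_state_dict"):
    print(key)
    for name, shape in summarize_state_dict(checkpoint[key]):
        print(f"  {name}: {shape}")
```

The printed names are the valid values for `parameter_name`, and the shapes often hint at the layer types you need to recreate in the model definition.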

Thank you so much for your reply :pray:, I am now able to load the weights thanks to you.
One extra question: I did exactly what you suggested, and this is what I got:

I want to use this model to make a prediction for one image.
Could you please help me figure out how to do so?

You would need to create a model object first and load the state_dict afterwards:

model = MyModel()
model.load_state_dict(encoder_state_dict)

where MyModel is defined in your code somewhere as a custom nn.Module class.
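
Since your checkpoint stores an encoder and a decoder separately, you would most likely need one module per state_dict. Here is a rough sketch of that pattern. `MyEncoder` and `MyDecoder` are hypothetical stand-ins (your real classes come from the code that trained the model), and the checkpoint is built in place so the snippet runs without your file.

```python
import torch
import torch.nn as nn


class MyEncoder(nn.Module):  # hypothetical stand-in for the real encoder
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


class MyDecoder(nn.Module):  # hypothetical stand-in for the real decoder
    def __init__(self, num_classes=5):
        super().__init__()
        self.head = nn.Conv2d(8, num_classes, kernel_size=1)

    def forward(self, features):
        return self.head(features)


# Stand-in for the dict returned by torch.load(...) on the real checkpoint.
checkpoint = {
    "encoder_state_dict": MyEncoder().state_dict(),
    "decoder_state_dict": MyDecoder().state_dict(),
}

encoder, decoder = MyEncoder(), MyDecoder()
encoder.load_state_dict(checkpoint["encoder_state_dict"])
decoder.load_state_dict(checkpoint["decoder_state_dict"])

# Switch dropout/batchnorm layers to inference behavior.
encoder.eval()
decoder.eval()

image = torch.rand(1, 1, 32, 128)  # dummy grayscale batch: (N, C, H, W)
with torch.no_grad():  # no gradients are needed for a prediction
    output = decoder(encoder(image))
print(output.shape)
```

`load_state_dict` raises an error if the parameter names or shapes don't match the module, so the class definitions have to agree with the ones used during training.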


Thank you so much for your help ^^
What I did is the following:

features = encoder(image)
output = decoder(features)
output

and I got the following output:

tensor([[[[-1.0373e+01, -9.2096e+00, -5.9885e+00,  ..., -5.2428e+02,
           -3.8752e+02, -4.4228e+02],
          [-2.8680e+00, -3.1839e+00, -2.2322e+00,  ..., -5.3969e+02,
           -5.8710e+02, -6.5721e+02],
          [-5.8041e+00, -6.1079e+00, -5.0006e+00,  ..., -5.9030e+02,
           -6.4768e+02, -7.5395e+02],
          ...

Do you have any idea how to get the prediction, which should be a paragraph, from this output?
PS: The CTC loss is then computed to produce the paragraph transcription, and I can't seem to figure out how it must be loaded.

I’m not familiar with your use case, but based on the output it seems you might be working on a sequence prediction task? If so, is each sample supposed to predict a word/token etc.?
In case you are using another repository as your code base, I would probably check if this repo might already implement a predict method.
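
If the model was trained with a CTC loss, note that the loss itself is only a training objective; at inference time the per-step class scores are usually decoded instead. Below is a minimal sketch of greedy (best-path) CTC decoding, under several assumptions that I cannot verify for your model: the scores have already been reshaped to `(time_steps, num_classes)`, the blank label sits at index 0, and `labels` is a made-up character set. `ctc_greedy_decode` is a hypothetical helper, not something taken from your repository.

```python
import torch
import torch.nn.functional as F


def ctc_greedy_decode(logits, labels, blank=0):
    """Greedy CTC decoding for a (time_steps, num_classes) score tensor.

    Takes the argmax at each step, collapses consecutive repeats,
    and drops the blank label.
    """
    best_path = logits.argmax(dim=1).tolist()
    chars = []
    previous = None
    for index in best_path:
        if index != previous and index != blank:
            chars.append(labels[index])
        previous = index
    return "".join(chars)


# Hypothetical label set; index 0 is assumed to be the CTC blank.
labels = ["<blank>", "a", "b", "c"]

# Fake scores whose argmax follows a known path, for illustration only.
path = torch.tensor([1, 1, 0, 1, 2, 2, 0, 0, 3])
logits = F.one_hot(path, num_classes=len(labels)).float()

print(ctc_greedy_decode(logits, labels))  # prints "aabc"
```

Your output tensor is 4D, so how it maps to a sequence of time steps (and which index is the blank) depends on the model and its training code; that part has to come from the original repository.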