I am trying to access the weights of a model. When I load the checkpoint, I find these keys:
dict_keys(['optimizer_state_dict', 'epoch', 'scaler_state_dict', 'best', 'curriculum_config', 'encoder_state_dict', 'decoder_state_dict'])
This is a screenshot of the output I have:
I would really appreciate your help.
Thank you in advance!
It seems you are loading a custom dict, which includes a few state_dicts. To check the weights (or any other parameter/buffer) of the model(s), you could use:
checkpoint = torch.load(...)
encoder_state_dict = checkpoint['encoder_state_dict']
print(encoder_state_dict[parameter_name])
# same for decoder_state_dict
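If the parameter names are not known yet, listing every key with its tensor shape is a quick way to find them. The following is only a sketch: the tiny `nn.Linear` "encoder" and the in-memory buffer are stand-ins (assumptions) so the example runs on its own; with a real checkpoint you would pass your file path to `torch.load` instead.

```python
import io

import torch
from torch import nn

# Stand-in checkpoint (assumption): a tiny layer saved to an in-memory buffer.
# With a real file, skip this part and call torch.load on your checkpoint path.
encoder = nn.Linear(4, 2)  # hypothetical "encoder"
buffer = io.BytesIO()
torch.save({"epoch": 3, "encoder_state_dict": encoder.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")
encoder_state_dict = checkpoint["encoder_state_dict"]

# Collect and print every parameter/buffer name together with its shape.
shapes = {name: tuple(tensor.shape) for name, tensor in encoder_state_dict.items()}
for name, shape in shapes.items():
    print(name, shape)
```

Any of the printed names can then be used as `parameter_name` in the snippet above.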
Thank you so much for your reply, I am now able to load the weights thanks to you.
One extra question, I did exactly what you told me and this is what I got:
I want to use this model to make a prediction for one image.
Could you please help me figure out how to do so?
You would need to create a model object first and load the state_dict afterwards:
model = MyModel()
model.load_state_dict(encoder_state_dict)
where MyModel is defined in your code somewhere as a custom nn.Module class.
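A slightly fuller sketch of the same steps, including switching to evaluation mode and running one input without gradient tracking. The `MyModel` definition, the input shape, and the stand-in state_dict are all hypothetical; your real class must match the architecture that produced the checkpoint, otherwise `load_state_dict` will report missing or unexpected keys.

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical placeholder; the real class must match the checkpoint."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


# Stand-in for checkpoint['encoder_state_dict'] (assumption, so the sketch runs alone).
trained = MyModel()
encoder_state_dict = trained.state_dict()

model = MyModel()
model.load_state_dict(encoder_state_dict)  # strict by default: key mismatches raise an error
model.eval()  # disables dropout and uses running statistics in batch norm layers

image = torch.rand(1, 1, 32, 128)  # hypothetical (batch, channels, height, width)
with torch.no_grad():  # no gradients are needed for inference
    features = model(image)
print(features.shape)
```

The same pattern would apply to the decoder with `decoder_state_dict`.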
Thank you so much for your help ^^
What I did is the following:
features = encoder(image)
output = decoder(features)
output
and I got the following output:
tensor([[[[-1.0373e+01, -9.2096e+00, -5.9885e+00, ..., -5.2428e+02,
-3.8752e+02, -4.4228e+02],
[-2.8680e+00, -3.1839e+00, -2.2322e+00, ..., -5.3969e+02,
-5.8710e+02, -6.5721e+02],
[-5.8041e+00, -6.1079e+00, -5.0006e+00, ..., -5.9030e+02,
-6.4768e+02, -7.5395e+02],
...
Do you have an idea about getting the prediction that is a paragraph from this output?
PS: The CTC loss is then computed to produce the paragraph transcription, and I can’t seem to figure out how it must be loaded.
I’m not familiar with your use case, but based on the output it seems you might be working on a sequence prediction task? If so, is each sample supposed to predict a word/token etc.?
In case you are using another repository as your code base, I would probably check if this repo might already implement a predict method.
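If the repository has no such method and the model was indeed trained with a CTC loss, a common way to turn raw outputs into text is greedy (best-path) decoding: take the most likely class at each time step, collapse consecutive repeats, and drop the blank symbol. The sketch below rests on several assumptions that may not hold for your model: the scores have shape (time, num_classes) for a single sample, the blank index is 0, and `index_to_char` is a hypothetical mapping that must come from the training code.

```python
import torch


def ctc_greedy_decode(logits, index_to_char, blank=0):
    """Greedy CTC decoding for one sample with scores of shape (time, num_classes)."""
    best_path = torch.argmax(logits, dim=-1).tolist()
    chars = []
    previous = blank
    for index in best_path:
        # Keep a symbol only if it is not blank and not a repeat of the previous step.
        if index != blank and index != previous:
            chars.append(index_to_char[index])
        previous = index
    return "".join(chars)


# Toy scores whose per-step argmax is [1, 1, 0, 1, 2, 2, 0].
path = [1, 1, 0, 1, 2, 2, 0]
logits = torch.full((len(path), 3), -5.0)
logits[torch.arange(len(path)), torch.tensor(path)] = 5.0

index_to_char = {1: "a", 2: "b"}  # hypothetical character set
text = ctc_greedy_decode(logits, index_to_char)
print(text)  # prints "aab"
```

The blank between the second and third steps is what allows the repeated "a" to survive. A 4-dimensional output like the one shown above would first need to be reshaped to (time, num_classes) in whatever way the training code does it.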