Use a different state_dict for inference

Hi everyone!

I remember that some time ago there was an announcement (I can't find the source now) of a new feature. It would allow a given model to run inference using a previously saved state_dict, rather than the weights of the model currently being trained.

The example that was shown in the feature announcement was something like this:

```python
import torch
import torch.nn as nn

from .models import MyModel
## ...

model_instance = MyModel()

# do some model training

previously_saved_state_dict = torch.load("path/to/pretrained/model.ckpt")

pseudolabels = model_instance.inference_with_dict_state(previously_saved_state_dict, input)

predictions = model_instance(input)

loss = my_loss_function(predictions, pseudolabels)
# ...
```

I cannot find anything like this in the documentation, and Googling has not turned up anything useful so far.
Was this a beta feature that was finally not supported? Is it somewhere hidden in the documentation? Or did I just dream this feature?

Thanks in advance! 🙂


This doesn't ring a bell, I'm afraid.
Two things come close, although neither is specific to inference:

  • You can just call `load_state_dict()` with this state dict and use the model as usual (on a second instance of the model if you don't want to overwrite the weights being trained).
  • There is `torch.func.functional_call()`, which evaluates the model with a given set of params/buffers instead of the ones contained in the model itself.
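A minimal sketch of both options for the pseudolabel use case, assuming PyTorch 2.x (where `torch.func` exists). `MyModel` here is a hypothetical toy stand-in for the model in the question, and a second instance's `state_dict()` plays the role of the `torch.load(...)` checkpoint. Computing the pseudolabels in eval mode and without gradients is also an assumption about what you want.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call


class MyModel(nn.Module):
    # Hypothetical toy model; BatchNorm is included so the state dict has buffers too.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)
        self.bn = nn.BatchNorm1d(3)

    def forward(self, x):
        return self.bn(self.fc(x))


torch.manual_seed(0)
model_instance = MyModel()  # the model being trained

# Stand-in for torch.load("path/to/pretrained/model.ckpt")
previously_saved_state_dict = MyModel().state_dict()

x = torch.randn(8, 4)

# Option 1: load the saved weights into a separate copy,
# so the weights of model_instance are left untouched.
teacher = copy.deepcopy(model_instance)
teacher.load_state_dict(previously_saved_state_dict)
teacher.eval()
with torch.no_grad():
    pseudolabels_1 = teacher(x)

# Option 2: run model_instance itself, but with the saved params/buffers
# substituted for this single call only.
model_instance.eval()
with torch.no_grad():
    pseudolabels_2 = functional_call(model_instance, previously_saved_state_dict, (x,))
model_instance.train()

# Train the current model against the pseudolabels as usual.
predictions = model_instance(x)
loss = nn.functional.mse_loss(predictions, pseudolabels_2)
loss.backward()
```

Option 1 keeps a second copy of the model in memory; option 2 reuses the same module, but the keys in the dict must match the names from `model.named_parameters()` / `model.named_buffers()`.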