Stateful vs. Stateless Models

Hi everyone. I want to implement a gradient-based meta-learning algorithm in PyTorch, and I found out that there is a library called higher, built on top of PyTorch, for implementing algorithms that take several steps of gradient descent in their inner loop. So I decided to go through the paper published for the library here:

However, there are a couple of things in the paper that I don’t understand. For example, the Obstacles section says that models in PyTorch and Keras are stateful, meaning the model encapsulates its parameters. Can someone explain intuitively, with an example, what the difference is between stateful and stateless models, and how to implement each in PyTorch?

Thanks! 🙂

Think of a trivial model such as y = mx + b, where x is the input, y is the output, m is the weight, and b is the bias.

Ignoring PyTorch, you could implement a model as something like:

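class Model():
    def __init__(self, weight, bias):
        # The parameters are stored on the instance: this is the model's "state"
        self.weight = weight
        self.bias = bias

    def forward(self, input):
        return input * self.weight + self.bias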

This would be “stateful” because the weight and bias are attributes of the instance, part of the “state” of the model object.
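PyTorch’s nn.Module follows the same stateful pattern. A small sketch, assuming a 1-D nn.Linear stands in for y = mx + b and using made-up data: the layer stores weight and bias as attributes, and optimizer.step() overwrites them in place, so the pre-update values are gone unless you copied them yourself.

```python
import torch
import torch.nn as nn

# nn.Linear(1, 1) computes y = x * weight^T + bias, i.e. y = m*x + b in one dimension
model = nn.Linear(1, 1)
print(list(model.state_dict().keys()))  # ['weight', 'bias']

opt = torch.optim.SGD(model.parameters(), lr=0.1)
weight_param = model.weight                    # the Parameter object held by the module
weight_before = model.weight.detach().clone()  # explicit copy of the old values

# Made-up training example for illustration
x = torch.tensor([[2.0]])
target = torch.tensor([[5.0]])

loss = ((model(x) - target) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()  # updates model.weight and model.bias in place

# Same Parameter object, new values: the module's state was mutated
print(model.weight is weight_param)                        # True
print(torch.equal(weight_before, model.weight.detach()))   # False
```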

Alternatively you could write:

class Model():
    def forward(self, input, weight, bias):
        return input * weight + bias

This would be stateless: the model is parameterized, receiving its parameters as arguments on every call. Any updates to the parameters happen outside the model.
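To make “updates happen outside the model” concrete, here is a plain-Python sketch (no PyTorch). It assumes a squared-error loss and a hand-derived gradient; the helper sgd_step and the numbers are hypothetical, chosen only for illustration. Because the step returns new parameter values instead of mutating anything, two different update paths can start from the same parameters without interfering.

```python
class Model():
    # Stateless: parameters are arguments, nothing is stored on the instance
    def forward(self, input, weight, bias):
        return input * weight + bias


def sgd_step(model, x, y, weight, bias, lr):
    """Hypothetical helper: one gradient step on the loss (pred - y)**2.

    Returns *new* parameter values; the inputs are not modified.
    """
    pred = model.forward(x, weight, bias)
    # d/dw (pred - y)^2 = 2 * (pred - y) * x ; d/db (pred - y)^2 = 2 * (pred - y)
    grad_w = 2 * (pred - y) * x
    grad_b = 2 * (pred - y)
    return weight - lr * grad_w, bias - lr * grad_b


model = Model()
w0, b0 = 1.0, 0.0

# Two update paths from the same starting point; neither affects the other
w_a, b_a = sgd_step(model, x=2.0, y=5.0, weight=w0, bias=b0, lr=0.1)
w_b, b_b = sgd_step(model, x=2.0, y=5.0, weight=w0, bias=b0, lr=0.01)

print(w0, b0)  # 1.0 0.0  (starting parameters untouched)
print(w_a, b_a)
print(w_b, b_b)
```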

The paper describes the move from the first example to the second as follows:

Here, it replaces the call to the forward method with one which first replaces the stateful parameters of the submodule with ones provided as additional arguments to the patched forward, before calling the original class’s bound forward method, which will now use the parameters provided at call time.

This is the same migration as in the example above: the parameters stop being read from the module and are passed in at call time instead.
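The same idea of swapping in call-time parameters can be tried without higher. The sketch below is not higher’s API; it uses torch.func.functional_call (PyTorch 2.0 or later, an assumption about your install), which runs a module’s forward with a dict of parameters supplied at call time and leaves the module’s own stored parameters as they were. The parameter values are made up.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # requires PyTorch >= 2.0

model = nn.Linear(1, 1)
weight_before = model.weight.detach().clone()
x = torch.tensor([[2.0]])

# Parameters supplied at call time instead of the ones stored in the module
new_params = {
    "weight": torch.tensor([[3.0]]),
    "bias": torch.tensor([1.0]),
}

# Runs model.forward using new_params (y = 2 * 3 + 1)
y = functional_call(model, new_params, (x,))
print(y.item())  # 7.0

# The module's own stored parameters are unchanged
print(torch.equal(weight_before, model.weight.detach()))  # True
```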

You may find information about stateful vs. stateless LSTMs or RNNs when searching for more info, but there “state” means the hidden state carried over between batches, which is a different kind of state from the parameters this paper is talking about.


Thanks for your clear explanation!

So, could we say that we use the stateless style of forward pass, with the parameters passed in, whenever we want to try different optimization steps from the same starting parameters (cloned for each path) so that the paths don’t “block” each other in their individual backward passes? Or are there other important reasons for this style of forward pass?

I haven’t fully read the paper, but I’d say the tradeoffs between stateful and stateless programming have more to do with computer science than with ML. When the paper mentions “pure” functions, it refers to the functional programming paradigm: a pure function’s output depends only on its arguments, and it doesn’t modify anything outside itself. In-place operations can be more memory-efficient, while pure functions can be easier to reason about and refactor.
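The scenario in the question above (several inner-loop paths from the same starting parameters, each differentiated through) can be sketched by hand with pure functions. This is a hypothetical illustration, not how higher is implemented: forward, loss_fn and inner_step are made-up helpers, the data and learning rates are invented, and the inner step uses torch.autograd.grad with create_graph=True so the outer loss can be backpropagated through both updates to the shared starting parameters.

```python
import torch

def forward(x, weight, bias):
    # Pure function: output depends only on the arguments, nothing is mutated
    return x * weight + bias

def loss_fn(pred, target):
    return ((pred - target) ** 2).mean()

def inner_step(weight, bias, x, y, lr):
    # Differentiable SGD step that returns NEW tensors instead of updating in place.
    # create_graph=True keeps this step in the graph so it can be differentiated later.
    loss = loss_fn(forward(x, weight, bias), y)
    g_w, g_b = torch.autograd.grad(loss, (weight, bias), create_graph=True)
    return weight - lr * g_w, bias - lr * g_b

# Shared starting parameters (what a meta-learner would learn)
w0 = torch.tensor(1.0, requires_grad=True)
b0 = torch.tensor(0.0, requires_grad=True)

# Made-up "train" and "validation" examples
x_train, y_train = torch.tensor([2.0]), torch.tensor([5.0])
x_val, y_val = torch.tensor([3.0]), torch.tensor([7.0])

# Two independent inner-loop paths from the same starting point
w_a, b_a = inner_step(w0, b0, x_train, y_train, lr=0.1)
w_b, b_b = inner_step(w0, b0, x_train, y_train, lr=0.01)

# Outer loss: validation loss of both paths, backpropagated through both inner steps
outer_loss = (loss_fn(forward(x_val, w_a, b_a), y_val)
              + loss_fn(forward(x_val, w_b, b_b), y_val))
outer_loss.backward()

print(w0.item(), b0.item())            # 1.0 0.0  (starting values never modified)
print(w0.grad.item(), b0.grad.item())  # meta-gradients w.r.t. the starting parameters
```

Since nothing is updated in place, each path keeps its own graph back to w0 and b0, and the two backward contributions simply add up in w0.grad and b0.grad.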