The main limitation here is that the parameters are “supposed” to remain parameters, i.e. leaf Tensors. So they cannot carry autograd history.
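To make that concrete, here is a tiny sketch (module name and values are just for illustration) showing that a registered parameter is a leaf, while the result of an update step carries a `grad_fn` and so could not be stored back as a plain parameter:

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
# Parameters are leaf tensors: no autograd history.
assert lin.weight.is_leaf
assert lin.weight.grad_fn is None

# An "updated" tensor such as w - lr * something has history,
# so it is no longer a leaf and cannot simply replace the parameter.
updated = lin.weight - 0.1
assert updated.grad_fn is not None
assert not updated.is_leaf
```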
I think the simplest way is to use the experimental API that you can find here: pytorch/_stateless.py at e4a9ee8d42449aced60660e9afdd5b8b1d0d29c5 · pytorch/pytorch · GitHub (you can install the nightly build to be able to use it).
This would look like:
for t in range(OUT_MAX_ITR):
    model.train()
    # Initially, params holds the actual parameters
    # (named_parameters() returns a generator, so materialize it as a dict).
    params = dict(model.named_parameters())
    for i in range(IN_MAX_ITR):
        outputs = _stateless.functional_call(model, params, xtr)
        loss = compute_loss(outputs)
        loss.backward(create_graph=True, inputs=list(params.values()))
        # Cannot use the plain optimizers here as they are not differentiable.
        params = {k: p - p.grad for k, p in params.items()}
    out_loss = function_of_params(params)
    out_loss.backward()
    # Here model.parameters() will have their .grad field populated.
    # If you have any other leaf Tensor (not sure what theta is in your example)
    # they will also get their .grad field populated.
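For reference, in current PyTorch releases this experimental API has since landed as `torch.func.functional_call`. Assuming PyTorch >= 2.0, the loop above can be sketched end-to-end as follows; the squared-error inner loss, the parameter-sum outer loss, the learning rate, and the model shape are all made-up stand-ins, and the differentiable inner step uses `torch.autograd.grad` rather than `backward(inputs=...)`:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # stable successor of _stateless

model = nn.Linear(3, 1)
xtr = torch.randn(8, 3)
ytr = torch.randn(8, 1)

# Materialize the parameters as a dict; these start as the real leaves.
params = dict(model.named_parameters())
for i in range(3):  # inner loop (IN_MAX_ITR = 3 here)
    out = functional_call(model, params, (xtr,))
    loss = ((out - ytr) ** 2).mean()  # stand-in for compute_loss
    # Differentiable gradient step: create_graph=True keeps the history
    # so the outer backward can reach the original leaves.
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    params = {k: p - 0.1 * g for (k, p), g in zip(params.items(), grads)}

# Outer loss: any function of the updated params (stand-in here).
out_loss = sum(p.sum() for p in params.values())
out_loss.backward()
# The original leaf parameters now have .grad populated.
assert model.weight.grad is not None
```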