Disable "in-place" updates in troch.nn

The main limitation here is that the parameters are “supposed” to remain parameters (an nn.Parameter is always a leaf Tensor), so they cannot carry autograd history.
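
For illustration, a minimal sketch of what goes wrong (using a throwaway nn.Linear; the 0.1 step size is arbitrary):

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
model(torch.randn(3, 2)).sum().backward()

# Assigning a plain Tensor (which could carry history) to a parameter attribute
# is rejected by nn.Module: only nn.Parameter (or None) can be assigned there.
try:
    model.weight = model.weight - 0.1 * model.weight.grad
except TypeError as e:
    print(e)

# Re-wrapping into nn.Parameter is accepted, but a Parameter is always a leaf,
# so the update is cut from the graph and is not differentiable anymore.
model.weight = nn.Parameter(model.weight - 0.1 * model.weight.grad)
print(model.weight.grad_fn)  # None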

I think the simplest way is to use the experimental API that you can find here: pytorch/_stateless.py at e4a9ee8d42449aced60660e9afdd5b8b1d0d29c5 · pytorch/pytorch · GitHub (you can install a nightly build to be able to use it).

This would look like:

for t in range(OUT_MAX_ITR):
    model.train()
    # initially params are the actual parameters
    params = dict(model.named_parameters())

    for i in range(IN_MAX_ITR):
        outputs = _stateless.functional_call(model, params, xtr)
        loss = compute_loss(outputs)
        loss.backward(create_graph=True, inputs=params.values())
        # Cannot use the plain optimizers here as they are not differentiable.
        for k, p in params.items():
            params[k] = p - p.grad
 
    out_loss = function_of_params(params)
    out_loss.backward()
    # Here model.parameters() will have their .grad field populated.
    # If you have any other leaf Tensor (not sure what theta is in your example)
    # they will also get their .grad field populated.
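
Once the .grad fields are populated by this outer backward, the outer update itself does not need to be differentiable, so a plain optimizer is fine there. A minimal sketch of that last step (SGD, OUT_LR and outer_opt are placeholders I'm introducing, not part of the code above; zero_grad(set_to_none=True) is used so the grad Tensors that are part of the inner graph are not modified in place):

outer_opt = torch.optim.SGD(model.parameters(), lr=OUT_LR)

# at the end of each outer iteration, instead of the bare out_loss.backward() above:
outer_opt.zero_grad(set_to_none=True)  # drop inner-loop grads still sitting on the leaves
out_loss = function_of_params(params)
out_loss.backward()
outer_opt.step()  # an in-place step is fine here, outside the differentiable inner loop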