How back-propagation works with weights from cloned tensors

Dear Community, I’m trying to understand why the following meta-learning pseudo-code works. Could you please give me some guidance?

param: dict[str, torch.Tensor]
optimizer = Adam(params=param.values())

def inner_loop(param, data):
    cloned_param = clone of param, tensor by tensor
    calculate something with cloned_param (using data)
    get the loss from that calculation
    gradients = autograd.grad(outputs=loss, inputs=cloned_param.values())
    use gradients to update cloned_param
    return the updated cloned_param

def outer_loop(data):
    adapted_parameters = inner_loop(param, part 1 of data)
    loss = calculate(adapted_parameters, part 2 of data)
    optimizer.zero_grad()
    loss.backward()  # [Q]
    optimizer.step()

[Q] Why does this (and the optimizer step) work? The loss comes from adapted_parameters, which is a clone of param that gets updated inside inner_loop, not from the original parameters (param).
My understanding was that we need to compute the loss with param itself if we want loss.backward() to update param. It confuses me that we use a loss computed from what is effectively a 'future' version of param to update param, and that it works.
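
To make [Q] concrete, here is a minimal single-tensor version of the pattern I mean (the learning rates, the squared-error "losses", and the create_graph=True choice are just placeholders/assumptions on my part):

import torch

param = torch.ones(1, requires_grad=True)        # the original parameter
optimizer = torch.optim.Adam([param], lr=0.1)

# inner loop: clone, compute a loss, take one gradient step on the clone
cloned_param = param.clone()                     # still connected to param in the graph
inner_loss = (cloned_param - 2.0) ** 2           # stand-in for the loss on part 1 of data
(grad,) = torch.autograd.grad(inner_loss, cloned_param, create_graph=True)
adapted_param = cloned_param - 0.01 * grad       # the "updated cloned_param"

# outer loop: the loss only touches adapted_param, yet param still gets a gradient
loss = (adapted_param - 3.0) ** 2                # stand-in for the loss on part 2 of data
optimizer.zero_grad()
loss.backward()
print(param.grad)    # populated, which is exactly the part I find confusing
optimizer.step()
print(param)         # param has moved away from 1.0
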
[Q2] What would be a nice way to organize the parameters other than a dictionary keyed by string names? For models with many layers/steps, typos in the parameter names have been a frustrating source of errors.
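
For reference, my current setup looks roughly like this (the layer names here are made-up placeholders), so a mistyped key only blows up at lookup time:

import torch

# current approach: a flat dict with hand-typed string keys
param = {
    "layer1_weight": torch.randn(64, 32, requires_grad=True),
    "layer1_bias": torch.zeros(64, requires_grad=True),
}

w = param["layer1_wieght"]   # typo -> KeyError only at runtime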

Thanks a lot!

Hi Gears!

I have no idea what you are trying to do here, but the following comment
might be helpful:

Backpropagation through a cloning operation works just fine. Consider:

>>> import torch
>>> torch.__version__
'2.1.2'
>>> param = torch.ones (1, requires_grad = True)
>>> opt = torch.optim.SGD ([param], lr = 0.1)
>>> pclone = param.clone()
>>> pclone
tensor([1.], grad_fn=<CloneBackward0>)
>>> loss = pclone**2
>>> loss.backward()
>>> param.grad
tensor([2.])
>>> opt.step()
>>> param
tensor([0.8000], requires_grad=True)

(Note that if we had used pclone = param.detach().clone(), the copy
would have been cut off from param's computation graph: no gradient
would flow back to param, param.grad would not be populated, and
opt.step() would not have modified param.)
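
For completeness, here is a small sketch of the detached case. (I add
requires_grad_ (True) to the copy so that loss.backward() has something
to differentiate; with a plain detach().clone(), the backward call itself
would complain that nothing requires grad.)

import torch

param = torch.ones (1, requires_grad = True)
opt = torch.optim.SGD ([param], lr = 0.1)

pclone = param.detach().clone().requires_grad_ (True)   # a new leaf, cut off from param
loss = pclone**2
loss.backward()

print (param.grad)    # None -- the graph stops at pclone
print (pclone.grad)   # tensor([2.]) -- the gradient lands on the copy instead

opt.step()            # param.grad is None, so param is left unchanged
print (param)         # tensor([1.], requires_grad=True)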

Best.

K. Frank

Thanks very much for your help, K. Frank! I re-read my instructions and it turns out I had misread some of them.