I’m working on an implementation of MAML (paper link). There are a few implementations out there, but from what I can see they all rely on the functional form of a model. I have a requirement to make this as general as possible so it can be used with a variety of underlying models. To that end, I would like to be able to copy an nn.Module instance in the same way one can copy a variable or parameter. The key is that gradients need to be able to flow back to the copied nn.Module.

I’ve tried to hack something together via state dicts and hooks, but I keep running into cases where I need to copy an nn.Module without removing it from the computational graph.

Edit: after reading through the other discussion, I think it’s not what I’m after. It focuses on how to properly copy a model and continue training with it. My use case is a little different: I want to clone the model and leave that cloning operation in the computational graph, so that I can differentiate all the way back to the original model.

tensor.clone() is a differentiable operation, which could be used on your parameters.
However, I’m not sure if that would be suitable for your approach or not.
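To illustrate what I mean by clone() being differentiable, here is a minimal sketch showing that gradients computed through a clone flow back to the source tensor:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.clone()        # clone() is recorded in the autograd graph
y.sum().backward()
print(x.grad)        # gradients flowed back through the clone to x
```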

I’ve been trying to use the tensor.clone() operation, but I’m not sure how to load the cloned tensors into a new nn.Module without using load_state_dict(), which appears to disconnect the parameters from the computation graph. I think you’re right that individually cloning the parameters is the way to go, but I don’t know how to attach the clones to a new nn.Module without that detach from the computation graph occurring.

After playing around a little more I think it comes down to being able to do the following:

Let f and g both be instances of the same nn.Module class. Push some input through g while g is in training mode, so the parameters of g are part of a computation graph. Let A be a parameter in f and B be the corresponding parameter in g, and set C = B.clone(). Here’s the bit I don’t know how to do: reassign A to be C.
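In code, the setup looks something like this (using nn.Linear as a stand-in for the module; the last step is the part I can’t figure out):

```python
import torch
import torch.nn as nn

f = nn.Linear(3, 2)
g = nn.Linear(3, 2)

# run some input through g so its parameters are part of a graph
out = g(torch.randn(5, 3)).sum()

A = f.weight     # a parameter in f
B = g.weight     # the corresponding parameter in g
C = B.clone()    # differentiable clone, still connected to B

# the missing step: reassign A to be C, i.e. make f.weight refer to C.
# Plain assignment (f.weight = C) fails, because an nn.Parameter
# cannot hold a non-leaf tensor like C.
```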

This seems like a use case that should be possible in PyTorch without being too hacky. Does anyone have a path forward?

Yes, this is quite far from what nn.Modules were built for, so it’s not super easy.

The thing to keep in mind here is that the original Module and the cloned one should NOT be the same.
In particular, the cloned one cannot use nn.Parameter(), because parameters can only contain leaf Tensors (tensors without autograd history).
So you most likely want to delete each parameter on the clone and re-assign the field with a clone of the corresponding value from the other module.
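That delete-and-reassign pattern can be sketched like this (clone_module is a hypothetical helper name, not a PyTorch API; it relies on nn.Module’s __delattr__/__setattr__ handling of parameters):

```python
import copy
import torch
import torch.nn as nn

def clone_module(module):
    """Return a copy of `module` whose parameters are differentiable
    clones of the originals, so gradients flow back to `module`."""
    clone = copy.deepcopy(module)
    for name, param in module.named_parameters():
        # walk down to the submodule that owns this parameter
        *path, leaf = name.split(".")
        sub = clone
        for part in path:
            sub = getattr(sub, part)
        # delete the nn.Parameter, then re-assign the field as a plain
        # (non-leaf) tensor that is still connected to the original
        delattr(sub, leaf)
        setattr(sub, leaf, param.clone())
    return clone

# usage: gradients through the clone reach the original module
f = nn.Linear(3, 2)
g = clone_module(f)
g(torch.randn(4, 3)).sum().backward()
print(f.weight.grad is not None)  # True
```

Note that after this the clone’s fields are plain tensors, so clone.parameters() is empty; the trainable parameters all live on the original module.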