Backpropagation with model ensembling after switching device

I am using an ensemble of MLPs.
Following this tutorial, I tried using vmap to vectorize the computations, but I ran into difficulties when changing device and backpropagating.
I checked this issue; however, it does not relate to device switching.

As far as I understand, switching device requires moving the stacked tensors to the new device. How can I do that without detaching them?

Thanks in advance,
Ori.

Could you describe your use case in more detail, along with the issues you are seeing?
A minimal and executable code snippet reproducing errors would be great for debugging it further.
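
In the meantime, here is a minimal sketch of what I assume the setup looks like, in case it helps narrow things down. The MLP sizes, the number of models, and the names `make_mlp` and `fmodel` are made up for illustration; only `stack_module_state`, `functional_call`, and `vmap` come from `torch.func`. Note that `.to(device)` by itself is differentiable, so it does not detach anything, but when the device actually changes it returns a non-leaf tensor and `.grad` is populated on the original stacked tensors rather than on the moved copies. The sketch instead recreates the stacked parameters as leaf tensors on the target device, so the gradients land on the tensors you would hand to an optimizer.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

# Falls back to CPU so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_mlp():
    # Hypothetical small MLP; the real architecture is not known.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


models = [make_mlp() for _ in range(3)]

# Stack the per-model parameters along a new leading dimension.
params, buffers = stack_module_state(models)

# Move to the target device and re-mark the results as leaf tensors,
# so that .grad is stored on the tensors living on that device.
params = {k: v.detach().to(device).requires_grad_() for k, v in params.items()}
buffers = {k: v.to(device) for k, v in buffers.items()}

# A "stateless" copy of one model; only its structure is used.
base = copy.deepcopy(models[0]).to("meta")


def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))


x = torch.randn(16, 4, device=device)

# Map over the model dimension of params/buffers, share the same input x.
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

loss = out.pow(2).mean()
loss.backward()
```

If your code differs from this (for example, if you call `.to(device)` on the stacked tensors after building the optimizer), posting that part would make the problem much easier to pin down.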