Convert jax array to torch tensor

Hello all,

Is there some way to load a JAX array into a torch tensor? A naive way of doing this would be:

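import numpy as np
import torch

# Device -> host copy: materializes the JAX array in CPU memory
np_array = np.asarray(jax_array)
# Host -> device copy: uploads the same data back to the GPU for PyTorch
torch_ten = torch.from_numpy(np_array).cuda()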

This is slow, because it copies the JAX array from the GPU into a CPU NumPy array and then copies it back onto the GPU.

Just to be clear: I am not interested in preserving any gradient information. Only the entries of the array need to be loaded.

Many thanks!

How does JAX handle GPU arrays?
If it uses something like CuPy, you can look up online how to convert it directly to PyTorch.
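Direct conversion of this kind between array libraries usually goes through the DLPack protocol: PyTorch provides `torch.from_dlpack`, and newer JAX arrays implement `__dlpack__`. The snippet below is a minimal sketch that assumes both installed libraries are recent enough to support this (older versions may not, which would match the reply below). The helper name `jax_to_torch` is hypothetical. The tensor it returns shares memory with the JAX array on the same device, so there is no host round trip, and no gradient information is carried over.

```python
import jax.numpy as jnp
import torch


def jax_to_torch(jax_array):
    """Hypothetical helper: wrap a JAX array as a torch tensor via DLPack.

    Assumes a JAX version whose arrays implement __dlpack__ and
    __dlpack_device__, so torch.from_dlpack can consume them directly.
    No data is copied; the tensor aliases the JAX buffer on its device.
    """
    return torch.from_dlpack(jax_array)


# Usage: on a GPU machine x lives on the GPU and t ends up on the same GPU.
x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = jax_to_torch(x)
print(t.shape, t.dtype, t.device)
```

Because the tensor aliases a JAX buffer, and JAX assumes its arrays are immutable, avoid in-place PyTorch operations on it; call `.clone()` first if you need an independent copy.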

Unfortunately, this does not seem to be implemented yet. See

Many thanks for your reply, Pan

