DasPantom
(Pan Kessel)
#1
Hello all,
Is there a way to load a JAX array into a torch tensor? A naive way of doing this would be:
import numpy as np
import torch

# Copies the data from the GPU to host memory, then back onto the GPU.
np_array = np.asarray(jax_array)
torch_ten = torch.from_numpy(np_array).cuda()
This would be slow, because it requires copying the JAX array from the GPU into a CPU NumPy array and then back onto the GPU.
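For reference, here is a minimal, self-contained sketch of that naive host round trip, assuming recent `jax`, `numpy` and `torch` installs. The helper name `jax_to_torch_via_host` is hypothetical, and it falls back to the CPU when no GPU is available so it can run anywhere.

```python
import jax.numpy as jnp
import numpy as np
import torch


def jax_to_torch_via_host(jax_array, device=None):
    """Hypothetical helper: copy a JAX array into a torch tensor via host memory.

    This is the naive device -> host -> device path, not a zero-copy transfer.
    """
    # np.asarray on a JAX array may return a read-only view, and
    # torch.from_numpy warns on non-writable arrays; np.array forces a
    # writable host copy instead.
    host_copy = np.array(jax_array)
    tensor = torch.from_numpy(host_copy)
    if device is None:
        # Assumption: use the GPU if PyTorch can see one, otherwise stay on CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return tensor.to(device)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = jax_to_torch_via_host(x)
print(t.shape, t.dtype, t.device)
```

Both `np.array` and `.to(device)` make full copies, so this sketch still pays the GPU to host to GPU cost that the question is trying to avoid.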
Just to be clear: I am not interested in any gradient information being preserved. Only the entries of the array need to be loaded.
Many thanks!
albanD
(Alban D)
#2
How does JAX handle GPU arrays?
If it uses something like CuPy, you can look up online how to convert those arrays directly to PyTorch.
DasPantom
(Pan Kessel)
#3
Unfortunately, this does not seem to be implemented yet. See
Many thanks for your reply, Pan