How can I convert a pre-trained PyTorch model to NumPy or, even better, to jax.numpy?
(Context: I need to use a pre-trained PyTorch model as a deterministic transformation within NumPyro, which is based on JAX.)
Here is a minimal example which I would appreciate help with:
```python
import torch.nn as nn
import jax.numpy as jnp

model = nn.Linear(4, 4)
z = jnp.array([5., 4., 3., 2.])
model(z)  # raises TypeError: nn.Linear expects a torch.Tensor input
```
It produces the error `TypeError: linear(): argument 'input' (position 1) must be Tensor, not DeviceArray`.
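One likely reason why converting z does not help inside NumPyro (a sketch; the exact failure in a full NumPyro model may differ): outside any JAX transformation, z holds concrete values and can be copied into a torch.Tensor, but under jax.jit, which NumPyro's MCMC samplers typically use, z is an abstract tracer with no values to copy. Even when the copy succeeds, JAX cannot differentiate through the PyTorch call. torch_forward is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
z = jnp.array([5.0, 4.0, 3.0, 2.0])

def torch_forward(x):
    # Copy the JAX array into NumPy, then wrap it as a torch.Tensor.
    return model(torch.from_numpy(np.array(x)))

# Eager call: z holds concrete values, so the copy succeeds.
eager_out = torch_forward(z)

# Traced call: inside jax.jit, x is a tracer without concrete values,
# so np.array(x) raises before PyTorch is ever reached.
try:
    jax.jit(torch_forward)(z)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True
```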
Converting z to a Torch tensor is not an option, since it causes issues with NumPyro. If I understand correctly, the only option is to port the model's architecture to jax.numpy.
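For a single nn.Linear, such a port can be small: copy the parameters out of the state_dict once, then re-implement the forward pass in jax.numpy. nn.Linear computes x @ W.T + b, with the weight stored with shape (out_features, in_features). A sketch, assuming the model fits on the CPU and only linear layers need porting; linear_jax is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stands in for the pre-trained model

# Copy every parameter out of PyTorch once, ahead of time.
# detach() drops autograd history; cpu() handles GPU-resident weights.
params = {
    name: jnp.asarray(t.detach().cpu().numpy())
    for name, t in model.state_dict().items()
}

def linear_jax(params, x):
    # Same computation as nn.Linear: x @ W.T + b.
    return x @ params["weight"].T + params["bias"]

z = jnp.array([5.0, 4.0, 3.0, 2.0])
y_jax = linear_jax(params, z)
y_jit = jax.jit(linear_jax)(params, z)  # traceable, unlike the torch call

# Reference output from the original PyTorch module.
with torch.no_grad():
    y_torch = model(torch.tensor(np.asarray(z))).numpy()
```

For an nn.Sequential of Linear and ReLU layers the same pattern should extend: the state_dict keys become "0.weight", "0.bias", "2.weight" and so on, and each activation is replaced by its JAX counterpart (for example jax.nn.relu). Layers with extra state, such as batch norm, also need their buffers copied and the evaluation-mode formula reproduced.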
The full description of my problem is in the Pyro Discussion Forum thread "How to use a pre-trained PyToch model within Numpyro?" (numpyro category).