How can I convert a pre-trained PyTorch model to NumPy or, even better, to `jax.numpy`?

(Context: I need to use a pre-trained PyTorch model as a deterministic transformation within NumPyro, which is based on JAX.)

Here is a minimal example I would appreciate help with:

```python
import torch.nn as nn
import jax.numpy as jnp

model = nn.Linear(4, 4)
z = jnp.array([5., 4., 3., 2.])
model(z)  # raises the TypeError shown below
```

It produces the error `TypeError: linear(): argument 'input' (position 1) must be Tensor, not DeviceArray`.

Converting `z` to a Torch tensor is not an option, since it causes issues with NumPyro. The only option (if I understand correctly) is to port the model's architecture to `jax.numpy`.
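Porting can be as simple as exporting the trained weights once and re-implementing the forward pass with `jax.numpy`. Below is a minimal sketch, assuming the network is built only from `nn.Linear` and ReLU layers; the helper names `torch_params_to_jax`, `linear_jax` and `mlp_jax`, and the small `nn.Sequential` model, are hypothetical illustrations rather than an existing API. PyTorch is used only to read the weights and to run a one-off sanity check, never on the JAX side.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn


def torch_params_to_jax(module):
    """Copy every tensor in the module's state_dict into a jax.numpy array."""
    return {
        name: jnp.asarray(tensor.detach().cpu().numpy())
        for name, tensor in module.state_dict().items()
    }


def linear_jax(params, x, prefix=""):
    """Pure-JAX equivalent of nn.Linear: y = x @ W.T + b.

    PyTorch stores the weight as (out_features, in_features), hence the transpose.
    """
    y = x @ params[prefix + "weight"].T
    bias = params.get(prefix + "bias")
    return y if bias is None else y + bias


def mlp_jax(params, x):
    """Hand-written mirror of the hypothetical Sequential below: Linear, ReLU, Linear."""
    h = jax.nn.relu(linear_jax(params, x, prefix="0."))
    return linear_jax(params, h, prefix="2.")


torch.manual_seed(0)
z = jnp.array([5.0, 4.0, 3.0, 2.0])

# The single layer from the question.
model = nn.Linear(4, 4)
params = torch_params_to_jax(model)
y_jax = linear_jax(params, z)

# A slightly deeper, hypothetical architecture.
# Its state_dict keys are "0.weight", "0.bias", "2.weight", "2.bias" (ReLU has no parameters).
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
mlp_params = torch_params_to_jax(mlp)
y_mlp_jax = mlp_jax(mlp_params, z)

# One-off sanity check against PyTorch, done outside any JAX/NumPyro tracing.
with torch.no_grad():
    z_torch = torch.from_numpy(np.array(z))  # np.array makes a writable float32 copy
    y_torch = model(z_torch).numpy()
    y_mlp_torch = mlp(z_torch).numpy()

print(np.allclose(y_jax, y_torch, atol=1e-4), np.allclose(y_mlp_jax, y_mlp_torch, atol=1e-4))
```

Since `params` is just a dict of `jnp` arrays and `linear_jax`/`mlp_jax` are pure functions, they should be usable inside a NumPyro model (for example as the value passed to `numpyro.deterministic`) and under `jax.jit`. Other layer types (convolutions, normalization) would each need their own hand-written JAX counterpart, with care taken over PyTorch's weight layouts.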

My full problem is described in the Pyro Discussion Forum thread "How to use a pre-trained PyTorch model within Numpyro?".