How to convert a pre-trained PyTorch model to jax.numpy?

How can I convert a pre-trained PyTorch model to NumPy or, even better, to jax.numpy?
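Below is a minimal sketch of the general approach. It assumes the pre-trained model is a simple feed-forward `nn.Sequential`; `torch_model` and its layer sizes are hypothetical stand-ins. The idea is to copy the parameters out of `state_dict()` once as `jax.numpy` arrays and then write the forward pass by hand in `jax.numpy`. Because the `state_dict` only carries the weights and not the architecture, the hand-written forward function has to mirror the layers exactly.

```python
import numpy as np
import torch
import torch.nn as nn
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the pre-trained model: a small MLP.
torch.manual_seed(0)
torch_model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# Copy every parameter out of PyTorch once, as jax.numpy arrays.
# Keys follow the Sequential indices: "0.weight", "0.bias", "2.weight", "2.bias".
params = {k: jnp.asarray(v.detach().cpu().numpy())
          for k, v in torch_model.state_dict().items()}

def mlp_jax(params, x):
    """jax.numpy re-implementation of the Sequential above."""
    # nn.Linear computes x @ weight.T + bias.
    h = jnp.tanh(x @ params["0.weight"].T + params["0.bias"])
    return h @ params["2.weight"].T + params["2.bias"]

x = jnp.array([5., 4., 3., 2.])
y_jax = jax.jit(mlp_jax)(params, x)

# Sanity check: compare with the original PyTorch model on the same input.
with torch.no_grad():
    y_torch = torch_model(torch.tensor(np.asarray(x))).numpy()
print(np.allclose(np.asarray(y_jax), y_torch, atol=1e-5))
```

Because `params` is an ordinary dict of arrays (a JAX pytree), `mlp_jax` can be passed to `jax.jit` or `jax.grad` like any other JAX function.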

(Context: I need to use a pre-trained PyTorch model as a deterministic transformation within NumPyro, which is based on JAX.)

Here is a minimal example I would appreciate help with:

import torch.nn as nn
import jax.numpy as jnp

model = nn.Linear(4, 4)
z = jnp.array([5., 4., 3., 2.])
model(z)  # raises the TypeError below

It produces the error TypeError: linear(): argument 'input' (position 1) must be Tensor, not DeviceArray.

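Converting z to a torch.Tensor is not an option, since it causes issues with NumPyro (which traces the model with JAX, and JAX cannot trace through PyTorch operations). If I understand correctly, the only option is to re-implement the model's architecture in jax.numpy and load the pre-trained weights into it.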
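Here is a sketch of that port for the minimal example above, assuming the model really is a single `nn.Linear`. `nn.Linear` computes `x @ weight.T + bias`, where `weight` has shape `(out_features, in_features)`. It is therefore enough to copy `weight` and `bias` out of PyTorch once and redo the matrix product in `jax.numpy`. The name `linear_jax` is hypothetical, and the comparison against PyTorch at the end serves only as a sanity check.

```python
import numpy as np
import torch
import torch.nn as nn
import jax.numpy as jnp

torch.manual_seed(0)
model = nn.Linear(4, 4)  # stand-in for the pre-trained model

# Copy the pre-trained parameters out of PyTorch once, as jax.numpy arrays.
# nn.Linear stores weight with shape (out_features, in_features).
W = jnp.asarray(model.weight.detach().cpu().numpy())
b = jnp.asarray(model.bias.detach().cpu().numpy())

def linear_jax(x):
    """jax.numpy re-implementation of nn.Linear: y = x @ W.T + b."""
    return x @ W.T + b

z = jnp.array([5., 4., 3., 2.])
y_jax = linear_jax(z)  # works on a JAX array, no torch.Tensor needed

# Sanity check against PyTorch on the same input.
y_torch = model(torch.tensor(np.asarray(z))).detach().numpy()
print(np.allclose(np.asarray(y_jax), y_torch, atol=1e-5))
```

PyTorch is only used once to read the weights. `linear_jax` itself is pure JAX and can be called from inside a NumPyro model.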

My full problem is described in this thread on the Pyro Discussion Forum (numpyro category): How to use a pre-trained PyTorch model within Numpyro?