How can I convert a pre-trained PyTorch model to NumPy or, even better, jax.numpy?
(Context: I need to use a pre-trained PyTorch model as a deterministic transformation within Numpyro, which is based on JAX)
Here is a minimal example which I would appreciate help with:
import torch.nn as nn
import jax.numpy as jnp

model = nn.Linear(4, 4)
z = jnp.array([5., 4., 3., 2.])
model(z)
It produces the error
TypeError: linear(): argument 'input' (position 1) must be Tensor, not DeviceArray.
Converting z to a Torch Tensor is not an option, since it causes issues with Numpyro. The only option (if I understand correctly) is to port the model’s architecture to jax.numpy.
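For the minimal example above, porting could look roughly like the sketch below. It assumes the model is a single nn.Linear layer; a real pre-trained model would need every layer rewritten this way. The names W, b and linear_jax are made up for illustration. The idea is to copy the learned parameters out of PyTorch as NumPy arrays and redo the forward pass with jax.numpy, so that no Torch Tensor is involved when the function is called.

```python
import numpy as np
import torch
import torch.nn as nn
import jax.numpy as jnp

# Stand-in for the pre-trained model (assumption: a single linear layer).
model = nn.Linear(4, 4)
model.eval()

# Copy the learned parameters out of PyTorch and into JAX arrays.
# nn.Linear stores its weight with shape (out_features, in_features).
W = jnp.asarray(model.weight.detach().cpu().numpy())
b = jnp.asarray(model.bias.detach().cpu().numpy())

def linear_jax(x):
    # Same computation as nn.Linear: y = x @ W.T + b
    return x @ W.T + b

z = jnp.array([5., 4., 3., 2.])
out_jax = linear_jax(z)

# Sanity check: the original PyTorch model on the same input values.
with torch.no_grad():
    out_torch = model(torch.tensor([5., 4., 3., 2.])).numpy()

print(np.allclose(np.asarray(out_jax), out_torch, atol=1e-5))
```

Since linear_jax only uses jax.numpy operations on fixed arrays, it should be usable as a deterministic transformation on JAX arrays, although I have not tried it inside a Numpyro model.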
The description of my full problem is here: How to use a pre-trained PyToch model within Numpyro? - numpyro - Pyro Discussion Forum