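I am looking for a PyTorch equivalent of the following JAX function, which returns the shape and dtype of its input without running any computation:

    import jax

    def jax_shape_get(input):
        return jax.eval_shape(lambda x: x, input)
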
https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html
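
To make the desired behavior concrete, here is a minimal sketch of how the JAX function above behaves. The `(3, 4)` shape and `float32` dtype are arbitrary values chosen for illustration; the input is a `jax.ShapeDtypeStruct` placeholder, so no real array is needed.

```python
import jax
import jax.numpy as jnp


def jax_shape_get(input):
    # Trace the identity function abstractly; no array computation is executed.
    return jax.eval_shape(lambda x: x, input)


# A ShapeDtypeStruct only describes an array (shape and dtype); it holds no data.
# The shape and dtype below are illustrative assumptions.
spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)

out = jax_shape_get(spec)
print(out.shape, out.dtype)
```

The result is itself a `jax.ShapeDtypeStruct`, so what I am after in PyTorch is a way to get the same kind of shape and dtype description without allocating or computing the actual tensor.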