How do I see the shape and dtype of a data structure holding PyTorch tensors and plain Python scalars?

I am looking for a PyTorch equivalent of the following JAX function:

def jax_shape_get(input):
    return jax.eval_shape(lambda x: x, input)

https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html

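import jax

def torch_shape_get(input):
    def h_shape_get(x):
        # Works for tensor leaves only: plain Python scalars
        # (int, float, ...) have no .dtype or .shape attribute.
        return x.dtype, x.shape

    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable name.
    # The result must be returned, otherwise the function yields None.
    return jax.tree_util.tree_map(h_shape_get, input)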
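
For comparison, here is a minimal sketch of the same idea without the JAX dependency. It assumes the structure is built only from dicts, lists, tuples and namedtuples, and it treats anything with both a `dtype` and a `shape` attribute as a tensor. The name `shape_get` and the choice to report scalars as `(type, ())` are my own assumptions, not an existing PyTorch API.

```python
def shape_get(tree):
    """Return a structure mirroring `tree` with (dtype, shape) at every leaf."""
    if isinstance(tree, dict):
        return {k: shape_get(v) for k, v in tree.items()}
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):
        # namedtuple: rebuild it with positional arguments
        return type(tree)(*(shape_get(v) for v in tree))
    if isinstance(tree, (list, tuple)):
        return type(tree)(shape_get(v) for v in tree)
    if hasattr(tree, "dtype") and hasattr(tree, "shape"):
        # Tensor-like leaf (torch.Tensor, numpy.ndarray, ...)
        return tree.dtype, tuple(tree.shape)
    # Plain Python scalar (int, float, bool, ...): report its type
    # and an empty shape, by the convention assumed in this sketch.
    return type(tree), ()
```

With PyTorch installed, `shape_get({"w": torch.zeros(2, 3), "lr": 0.1})` should return a dict where `"w"` maps to `(torch.float32, (2, 3))` and `"lr"` maps to `(float, ())`. PyTorch also ships a private helper module, `torch.utils._pytree`, with a `tree_map` function that could replace the manual recursion, but as a private module its interface may change between releases.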