Converting a numpy dtype to torch dtype

I’d like to know the torch dtype that will result from applying torch.from_numpy(array), without actually calling this function. Since torch and numpy dtypes are not interchangeable (e.g. torch.zeros(some_shape, dtype=array.dtype) raises an error), how can I do that?

I’m not sure whether such a method is exposed in the Python API, but you could create a custom dict mapping the numpy dtypes to the PyTorch ones.
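
A minimal sketch of such a dict might look like the following. The names `NUMPY_TO_TORCH_DTYPE` and `numpy_to_torch_dtype` are made up for illustration and are not part of PyTorch, and the mapping only covers the dtypes listed here, so it is an assumption that these are the ones you need.

```python
import numpy as np
import torch

# Hand-written mapping (hypothetical name, not a PyTorch API).
# Keys are normalized np.dtype objects so that np.float32, "float32"
# and array.dtype all resolve to the same entry.
NUMPY_TO_TORCH_DTYPE = {
    np.dtype(np.bool_): torch.bool,
    np.dtype(np.uint8): torch.uint8,
    np.dtype(np.int8): torch.int8,
    np.dtype(np.int16): torch.int16,
    np.dtype(np.int32): torch.int32,
    np.dtype(np.int64): torch.int64,
    np.dtype(np.float16): torch.float16,
    np.dtype(np.float32): torch.float32,
    np.dtype(np.float64): torch.float64,
    np.dtype(np.complex64): torch.complex64,
    np.dtype(np.complex128): torch.complex128,
}


def numpy_to_torch_dtype(np_dtype):
    """Look up the torch dtype for a numpy dtype (hypothetical helper)."""
    key = np.dtype(np_dtype)
    try:
        return NUMPY_TO_TORCH_DTYPE[key]
    except KeyError:
        raise TypeError(f"no torch dtype known for numpy dtype {key}") from None


# Usage: build a tensor with the dtype matching an existing array,
# without converting the array itself.
array = np.zeros((2, 3), dtype=np.float32)
zeros = torch.zeros(array.shape, dtype=numpy_to_torch_dtype(array.dtype))
```

Unsupported dtypes (for example strings or non-native byte orders) raise a `TypeError` here instead of silently picking a fallback.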

Fun fact, we actually have such a dict in our test suite now: https://github.com/pytorch/pytorch/blob/e180ca652f8a38c479a3eff1080efe69cbc11621/torch/testing/_internal/common_utils.py#L349.
