Converting a numpy dtype to torch dtype

I’d like to know which torch dtype torch.from_numpy(array) will produce, without actually calling that function. Since torch and NumPy dtypes are not interchangeable (e.g. torch.zeros(some_shape, dtype=array.dtype) raises an error), how can I do that?


I’m not sure if such a mapping is exposed in the Python API, but you could create a custom dict that maps the NumPy dtypes to the corresponding PyTorch ones.
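A minimal sketch of such a dict, assuming you only need the common numeric dtypes that torch.from_numpy supports. The names `NUMPY_TO_TORCH_DTYPE`, `numpy_to_torch_dtype` and `check_mapping` are hypothetical, not part of PyTorch. The check converts an empty array of each dtype, which is cheap and never touches your real data:

```python
import numpy as np
import torch

# Hypothetical mapping: NumPy dtype -> torch dtype that torch.from_numpy produces.
# Only commonly supported, native-byte-order dtypes are covered.
NUMPY_TO_TORCH_DTYPE = {
    np.dtype(np.bool_): torch.bool,
    np.dtype(np.uint8): torch.uint8,
    np.dtype(np.int8): torch.int8,
    np.dtype(np.int16): torch.int16,
    np.dtype(np.int32): torch.int32,
    np.dtype(np.int64): torch.int64,
    np.dtype(np.float16): torch.float16,
    np.dtype(np.float32): torch.float32,
    np.dtype(np.float64): torch.float64,
    np.dtype(np.complex64): torch.complex64,
    np.dtype(np.complex128): torch.complex128,
}


def numpy_to_torch_dtype(np_dtype):
    """Return the torch dtype for np_dtype; raises KeyError if it is not in the dict."""
    # np.dtype() normalizes inputs such as np.float32, "float32" or array.dtype.
    return NUMPY_TO_TORCH_DTYPE[np.dtype(np_dtype)]


def check_mapping():
    """Verify the dict against torch.from_numpy using empty (zero-size) arrays."""
    for np_dtype, torch_dtype in NUMPY_TO_TORCH_DTYPE.items():
        assert torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype == torch_dtype
```

With this, `torch.zeros(some_shape, dtype=numpy_to_torch_dtype(array.dtype))` works where passing `array.dtype` directly fails.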


Fun fact, we actually have such a dict in our test suite now: