# `torch.sqrt()` and other math functions requiring tensor inputs

I frequently end up writing extra type checks when I write a function that uses these math functions. For example, say I want a single utility function that computes the Euclidean distance, and across my whole project I might call it on either tensors or Python scalars.

For Python scalars, I have to use something like:

```python
import math

def euclidean_distance(a, b):
    """a and b are either Python scalars or tensors with size 1"""
    return math.sqrt(a**2 + b**2)
```
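As a quick check of the claim that follows, here is a minimal sketch (assuming a standard PyTorch install) showing that `math.sqrt` accepts plain scalars and single-element tensors, which Python converts through the tensor's `float()` conversion, but rejects tensors with more than one element. The exact exception type for the failing case may vary between PyTorch versions, so the sketch catches several.

```python
import math

import torch


def euclidean_distance(a, b):
    """a and b are either Python scalars or tensors with a single element."""
    return math.sqrt(a**2 + b**2)


# Plain Python scalars work and return a Python float.
print(euclidean_distance(3, 4))  # 5.0

# A single-element tensor is converted to a float, so the result is a float too.
print(euclidean_distance(torch.tensor([3.0]), torch.tensor([4.0])))  # 5.0

# A multi-element tensor cannot be converted to a single float.
try:
    euclidean_distance(torch.tensor([3.0, 6.0]), torch.tensor([4.0, 8.0]))
except (RuntimeError, TypeError, ValueError) as err:
    print("math.sqrt failed:", type(err).__name__)
```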

This works for all Python scalars and for torch tensors with a single element. For general PyTorch tensors, though, I have to use:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors"""
    return torch.sqrt(a**2 + b**2)
```
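For comparison, a minimal sketch of the tensor-only version, assuming its body is `torch.sqrt(a**2 + b**2)` as in the fuller example further down. It works elementwise on tensors of any (broadcastable) shape, but `torch.sqrt` itself rejects a plain Python float with a `TypeError`.

```python
import torch


def euclidean_distance(a, b):
    """a and b are arbitrary tensors (broadcastable to a common shape)."""
    return torch.sqrt(a**2 + b**2)


# Elementwise over whole tensors: the result is a tensor of shape (2,).
print(euclidean_distance(torch.tensor([3.0, 6.0]), torch.tensor([4.0, 8.0])))

# Python scalars are rejected because torch.sqrt expects a Tensor argument.
try:
    euclidean_distance(3.0, 4.0)
except TypeError as err:
    print("torch.sqrt failed:", err)
```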

Is there some way to allow for:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    return torch.sqrt(a**2 + b**2)  # desired: also accept Python scalars
```
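One possible workaround for this particular function (not something proposed in the thread, and not a change to `torch.sqrt` itself): Python's `**` operator is resolved by the operands' own types, so `(a**2 + b**2) ** 0.5` uses float arithmetic for Python scalars and `Tensor.__pow__` for tensors, with no type checks at all. A minimal sketch, which only helps where an operator equivalent of the math function exists:

```python
import torch


def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars."""
    # `** 0.5` dispatches on the operand types: ordinary float arithmetic for
    # Python scalars, elementwise Tensor.__pow__ for tensors.
    return (a**2 + b**2) ** 0.5


print(euclidean_distance(3, 4))  # 5.0 (a Python float)
print(euclidean_distance(torch.tensor([3.0, 6.0]), 4.0))  # a tensor of shape (2,)
```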

without users having to write checks like:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    # Handle tensor vs not tensor inputs
    a_not_tensor = False
    if not torch.is_tensor(a):
        a_not_tensor = True
        a = torch.tensor([a])
    b_not_tensor = False
    if not torch.is_tensor(b):
        b_not_tensor = True
        b = torch.tensor([b])

    # Compute the distance
    dist = torch.sqrt(a**2 + b**2)

    # If both inputs are Python scalars, return a Python scalar
    if a_not_tensor and b_not_tensor:
        return dist.item()

    # If either input was a tensor, return a tensor
    return dist
```
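The same checks can at least be written once instead of in every function. A minimal sketch, assuming a hypothetical decorator named `accepts_scalars` (not part of PyTorch): it converts each non-tensor positional argument with `torch.as_tensor`, calls the wrapped function, and converts the result back to a Python number with `.item()` only when none of the inputs were tensors. Keyword arguments are not handled in this sketch.

```python
import functools

import torch


def accepts_scalars(fn):
    """Hypothetical helper: let a tensor-only function also take Python scalars."""

    @functools.wraps(fn)
    def wrapper(*args):
        any_tensor = any(torch.is_tensor(x) for x in args)
        # Convert Python scalars to 0-dim tensors; leave tensors untouched.
        tensors = [x if torch.is_tensor(x) else torch.as_tensor(x) for x in args]
        result = fn(*tensors)
        # Mirror the input types: all scalars in -> Python scalar out.
        return result if any_tensor else result.item()

    return wrapper


@accepts_scalars
def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars."""
    return torch.sqrt(a**2 + b**2)


print(euclidean_distance(3.0, 4.0))  # 5.0
print(euclidean_distance(torch.tensor([3.0, 6.0]), 4.0))  # a tensor of shape (2,)
```

The decorator can also wrap a PyTorch function directly, e.g. `accepts_scalars(torch.sqrt)(2.0)`, though it still has to be applied to each function by hand.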

Is there a reason this functionality doesn't exist? It isn't limited to `torch.sqrt()`; it applies to all PyTorch math functions. I imagine it's easier to make `torch.sqrt()` accept Python scalars than to automatically vectorize `math.sqrt()` (à la Python built-ins like `abs()`), but I suppose either would solve the problem.
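The `abs()` comparison can be made concrete. The built-in `abs()` simply calls its argument's `__abs__` method, which `torch.Tensor` implements elementwise, so one call site handles both scalars and tensors. The functions in the `math` module have no such hook and first convert their argument to a single float. A minimal sketch using `math.fabs` as the counterpart (the exact exception type may vary between PyTorch versions):

```python
import math

import torch

t = torch.tensor([-4.0, 9.0])

# The built-in abs() dispatches to Tensor.__abs__, so it works on both.
print(abs(-4.0))  # 4.0
print(abs(t))     # a tensor with the absolute values [4., 9.]

# math.fabs converts its argument to one float, so a multi-element
# tensor is rejected.
try:
    math.fabs(t)
except (RuntimeError, TypeError, ValueError) as err:
    print("math.fabs failed:", type(err).__name__)
```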

[An aside: this is somewhat related to the lack of a `torch.pi` constant (1, 2, 3). Requiring `import math` for certain things is a fine way to avoid redundant constants, but it would also be useful to provide this cross-functionality when needed.]


I sorely miss every single point you mention. Are there any GitHub issues related to them? I didn't find any.