I frequently run into extra type checking when I write functions that use these math functions. For example, let’s say I want a single utility function that computes the Euclidean distance sqrt(a**2 + b**2). Across my whole project, I might call it on tensors or on Python scalars.
For Python scalars, I have to use something like:
```python
import math

def euclidean_distance(a, b):
    """a and b are either Python scalars or tensors with size 1"""
    return math.sqrt(a**2 + b**2)
```
This works for all Python scalars and for torch tensors of size 1. But for general PyTorch tensors, I have to use:
```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors"""
    return torch.sqrt(a**2 + b**2)
```
Is there some way to allow for:
```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    return torch.sqrt(a**2 + b**2)
```
without users having to write checks like:
```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    # Handle tensor vs not tensor inputs
    a_not_tensor = False
    if not torch.is_tensor(a):
        a_not_tensor = True
        a = torch.tensor([a])
    b_not_tensor = False
    if not torch.is_tensor(b):
        b_not_tensor = True
        b = torch.tensor([b])
    # Compute the distance
    dist = torch.sqrt(a**2 + b**2)
    # If both inputs are Python scalars, return a Python scalar
    if a_not_tensor and b_not_tensor:
        return dist.item()
    # If either input was a tensor, return a tensor
    return dist
```
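For comparison, here is a minimal sketch of how the same check can at least be condensed, assuming the goal is "Python float in, Python float out; tensor in, tensor out". It wraps scalars as 0-d tensors rather than `torch.tensor([a])`, so mixing a scalar with a tensor keeps the tensor's shape instead of adding a dimension. The `float()` call is my own choice to avoid building integer tensors from integer scalars; it is not something PyTorch requires.

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    both_scalars = not (torch.is_tensor(a) or torch.is_tensor(b))
    # Wrap Python scalars as 0-d floating-point tensors; leave tensors as-is
    a = a if torch.is_tensor(a) else torch.as_tensor(float(a))
    b = b if torch.is_tensor(b) else torch.as_tensor(float(b))
    dist = torch.sqrt(a**2 + b**2)
    # Return a Python float only when neither input was a tensor
    return dist.item() if both_scalars else dist
```

The check still exists, though; it has only moved into fewer lines.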
Is there a reason this functionality doesn’t exist? It’s not limited to `torch.sqrt()`, but applies to all PyTorch math functions. I imagine it’s easier to have `torch.sqrt()` accept Python scalars than to automatically vectorize `math.sqrt()` (à la Python built-ins like `abs()`), but I suppose either would solve the problem.
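As a rough illustration of the first option, a user-side wrapper can make any elementwise torch function accept Python scalars. This is a sketch under my own assumptions: `accepts_scalars` is a hypothetical helper (not a PyTorch API), it only handles positional arguments, and it returns a Python float when every argument was a scalar, mirroring `math.*`.

```python
import torch

def accepts_scalars(torch_fn):
    """Hypothetical helper (not part of PyTorch): lets an elementwise
    torch function such as torch.sqrt also take Python scalars."""
    def wrapper(*args):
        all_scalars = not any(torch.is_tensor(x) for x in args)
        # Wrap non-tensor arguments as 0-d floating-point tensors
        tensors = [x if torch.is_tensor(x) else torch.tensor(float(x)) for x in args]
        out = torch_fn(*tensors)
        # Scalars in -> Python float out; otherwise keep the tensor
        return out.item() if all_scalars else out
    return wrapper

sqrt = accepts_scalars(torch.sqrt)
hypot = accepts_scalars(torch.hypot)
```

With this, `sqrt(16)` gives the Python float `4.0`, while `sqrt(torch.tensor([4.0, 9.0]))` still returns a tensor, so one function body can serve both cases.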
[An aside: this is somewhat related to the lack of a `torch.pi` constant (1, 2, 3). Requiring `import math` for certain things is fine as a way to avoid redundant constants, etc., but it would also be useful to provide the cross-functionality when needed.]