torch.sqrt() and other math functions requiring tensor inputs

I frequently end up writing extra type checks when a function uses these math functions. For example, say I want a single utility function that computes Euclidean distance, and across my whole project I might call it on tensors or on Python scalars.

For Python scalars, I have to use something like:

```python
import math

def euclidean_distance(a, b):
    """a and b are either Python scalars or tensors with size 1"""
    return math.sqrt(a**2 + b**2)
```

This works for Python scalars and for single-element tensors (in the latter case it returns a plain Python float, not a tensor). For general PyTorch tensors, though, I have to use:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors"""
    return torch.sqrt(a**2 + b**2)
```
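To make the mismatch concrete, here is a small check, assuming a recent PyTorch (exact exception types and messages differ between versions, so several are caught). `math.sqrt` accepts a one-element tensor but hands back a plain Python float, it rejects multi-element tensors, and `torch.sqrt` rejects a bare Python float:

```python
import math

import torch

# math.sqrt converts a one-element tensor via float(), so the result is a Python float
r = math.sqrt(torch.tensor([4.0]))
print(type(r), r)

# ...but a tensor with more than one element cannot be converted to a single scalar
math_failed = False
try:
    math.sqrt(torch.tensor([1.0, 4.0]))
except (RuntimeError, ValueError, TypeError) as err:
    math_failed = True
    print("math.sqrt failed:", err)

# torch.sqrt, in turn, requires a Tensor and rejects a plain Python float
torch_failed = False
try:
    torch.sqrt(4.0)
except TypeError as err:
    torch_failed = True
    print("torch.sqrt failed:", err)
```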

Is there some way to allow:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    return torch.sqrt(a**2 + b**2)
```

without users having to write checks like:

```python
import torch

def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    # Handle tensor vs. non-tensor inputs. Scalars are wrapped as 0-dim tensors
    # (torch.tensor(a), not torch.tensor([a])) so they broadcast against the
    # other input without adding a dimension to the result.
    a_not_tensor = not torch.is_tensor(a)
    if a_not_tensor:
        a = torch.tensor(a)
    b_not_tensor = not torch.is_tensor(b)
    if b_not_tensor:
        b = torch.tensor(b)

    # Compute the distance
    dist = torch.sqrt(a**2 + b**2)

    # If both inputs were Python scalars, return a Python scalar
    if a_not_tensor and b_not_tensor:
        return dist.item()

    # If either input was a tensor, return a tensor
    return dist
```
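The explicit checks can at least be written once and reused across functions. Below is a minimal sketch; `scalar_friendly` is a hypothetical helper (not a PyTorch API), it only converts positional arguments, and it assumes the wrapped function is built from torch operations. It relies on `torch.as_tensor`, which returns existing tensors unchanged and wraps Python scalars as 0-dim tensors:

```python
import functools

import torch


def scalar_friendly(fn):
    """Hypothetical decorator: lets a tensor-only function accept Python scalars.

    Positional args are passed through torch.as_tensor (scalars become 0-dim
    tensors, so they broadcast without changing any tensor argument's shape).
    If no argument was a tensor, a single-element result is unwrapped with .item().
    """

    @functools.wraps(fn)
    def wrapper(*args):
        any_tensor = any(torch.is_tensor(x) for x in args)
        out = fn(*(torch.as_tensor(x) for x in args))
        # Mirror the input types: all-scalar inputs -> Python scalar output
        if not any_tensor and torch.is_tensor(out) and out.numel() == 1:
            return out.item()
        return out

    return wrapper


@scalar_friendly
def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars"""
    return torch.sqrt(a**2 + b**2)


print(euclidean_distance(3.0, 4.0))
print(euclidean_distance(torch.tensor([3.0, 6.0]), 4.0))
```

The type check still happens, but it lives in one place instead of in every function that calls into `torch`.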

Is there a reason this functionality doesn’t exist? It isn’t limited to torch.sqrt(); it applies to PyTorch math functions in general. I imagine it’s easier to have torch.sqrt() accept Python scalars than to automatically vectorize math.sqrt() (à la Python built-ins like abs()), but I suppose either would solve the problem.
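For the square root specifically, one workaround (a sketch, not an official recommendation) follows the abs() analogy: built-ins and operators work on both types because Python routes them through special methods such as `__abs__` and `__pow__`, which `torch.Tensor` implements. Writing the root as `** 0.5` therefore takes the same path for scalars and tensors:

```python
import torch


def euclidean_distance(a, b):
    """a and b are arbitrary tensors or Python scalars (real-valued)."""
    # ** dispatches through __pow__, so this works for Python numbers and tensors
    # alike. The base is a sum of squares, hence non-negative for real inputs
    # (a negative Python float raised to 0.5 would give a complex number).
    return (a**2 + b**2) ** 0.5


print(euclidean_distance(3.0, 4.0))  # Python float in, Python float out
print(euclidean_distance(torch.tensor([3.0, 6.0]), 4.0))  # tensor out
print(abs(-2.5), abs(torch.tensor([-2.5, 1.0])))  # abs() dispatches via __abs__
```

This only covers operations that have an operator or built-in protocol; functions like `torch.exp` or `torch.log` have no such hook, so for those some form of explicit conversion is still needed.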

[An aside: this is somewhat related to the lack of a torch.pi constant (1, 2, 3). Requiring import math for some things is fine if it avoids redundant constants, etc., but it would also be useful to provide this cross-functionality when needed.]
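On the constants side, importing `math` does already interoperate with tensors: `math.pi` is a plain Python float, and tensors combine with Python scalars without changing their dtype. A minimal illustration:

```python
import math

import torch

angles_deg = torch.tensor([0.0, 90.0, 180.0])
# math.pi is a Python float, so it combines with a tensor like any other scalar;
# the result keeps the tensor's dtype (float32 here)
angles_rad = angles_deg * math.pi / 180.0
print(angles_rad)
print(torch.sin(angles_rad))
```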


I badly miss every single point you mention. Are there any GitHub issues related to them? I didn’t find any.