Compile function that uses constant tensor


Imagine a setup like:

from torch import Tensor

def construct(device, dtype) -> Tensor:
    return ...  # some complicated code to construct a tensor with the given device and dtype

def f(x):
    constant_tensor = construct(x.device, x.dtype)
    return x + constant_tensor

Now, I want to compile f, but I do not want to compile the construct function. Instead, I'd like constant_tensor to be a constant in the compiled graph. What is the best way to achieve this?

Also, if my main module calls f multiple times, I'd like the constant to be computed and stored in the graph only once. Without compilation, I'd wrap construct in functools.lru_cache to memoize the result. However, when I compile f with the cache in place, the compiler tries to compile the caching function, which fails.
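For reference, the eager-mode memoization described above might look like the minimal sketch below. To keep it runnable without PyTorch or a GPU, it assumes a hypothetical stand-in `construct` that counts its calls and returns a tuple, with plain strings standing in for `torch.device` and `torch.dtype` (both of which are hashable, so the real objects also work as cache keys).

```python
import functools

call_count = 0  # tracks how often the (hypothetical) expensive body runs

@functools.lru_cache(maxsize=None)
def construct(device, dtype):
    # Hypothetical stand-in: a real version would build and return a tensor
    # on `device` with `dtype`.
    global call_count
    call_count += 1
    return ("constant", device, dtype)

a = construct("cpu", "float32")
b = construct("cpu", "float32")   # same key: served from the cache
c = construct("cuda", "float32")  # new key: constructed again

print(call_count)              # 2
print(a is b)                  # True
print(construct.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)
```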


I think I've found a solution. We can make an lru_cache-like memoizer that looks like an ordinary dictionary to the compiler. The compiler then doesn't trace the construction function; instead, it treats every key as already present when compiling the function.
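The memoizer relies on Python's `dict.__missing__` hook: on a dict subclass, `d[key]` calls `__missing__` only when the key is absent, while `key in d` and `d.get(key)` never call it. A minimal, torch-free illustration of that hook (the `Squares` class is made up for this example):

```python
class Squares(dict):
    """dict subclass that fills in missing keys on first subscript access."""

    def __init__(self):
        super().__init__()
        self.computed = []  # records which keys triggered __missing__

    def __missing__(self, key):
        # Only reached from d[key] when key is absent.
        self.computed.append(key)
        value = key * key
        self[key] = value  # store it so later lookups are plain dict hits
        return value

sq = Squares()
print(3 in sq)      # False: membership tests never call __missing__
print(sq.get(3))    # None: .get() does not call it either
print(sq[3])        # 9: computed by __missing__
print(sq[3])        # 9: now an ordinary dict hit
print(sq.computed)  # [3]
```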

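import torch

class Memoizer(dict):
    """lru_cache-like memoizer that torch.compile sees as a plain dict."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def __missing__(self, item):
        # Called only for keys not yet in the dict: build and store the constant.
        tensor = self.fn(*item)
        self[item] = tensor
        return tensor

    def __call__(self, *args):
        # The positional arguments (dtype, device) form the cache key.
        return self[args]

@Memoizer
def constant_constructor(dtype, device):
    print("construct constant")
    return torch.ones(1, dtype=dtype, device=device)

def f(x):
    cnst = constant_constructor(x.dtype, x.device)
    return x + cnst

f_compiled = torch.compile(f, fullgraph=True)
print(f_compiled(torch.ones(1)))  # -> construct constant (new CPU key)
if torch.cuda.is_available():
    print(f_compiled(torch.ones(1).cuda()))  # -> construct constant (new CUDA key)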
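To check the memoizer's caching on its own, without torch.compile, the sketch below reuses the Memoizer class from above on a hypothetical counting constructor (`make_constant`) that returns strings instead of tensors. It only shows the eager behavior: repeated arguments return the identical cached object, and construction runs once per distinct (dtype, device) key.

```python
class Memoizer(dict):
    """Cache keyed on the positional-argument tuple, filled via __missing__."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def __missing__(self, item):
        value = self.fn(*item)
        self[item] = value
        return value

    def __call__(self, *args):
        return self[args]

calls = []  # records every real (uncached) construction

@Memoizer
def make_constant(dtype, device):
    # Hypothetical stand-in for constant_constructor.
    calls.append((dtype, device))
    return f"ones({dtype}, {device})"

x = make_constant("float32", "cpu")
y = make_constant("float32", "cpu")   # cache hit: same object as x
z = make_constant("float32", "cuda")  # new key: constructed

print(x is y)              # True
print(len(calls))          # 2
print(sorted(make_constant))  # [('float32', 'cpu'), ('float32', 'cuda')]
```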