```python
from torch import Tensor

def construct(device, dtype) -> Tensor:
    return ...  # some complicated code to construct a tensor with given device and dtype

def f(x):
    constant_tensor = construct(x.device, x.dtype)
    return x + constant_tensor
```
Now I want to compile `f`, but I do not want to compile the `construct` function. Instead, I’d like `constant_tensor` to be a constant in the compiled graph. What is the best way to achieve this?
Also, if my main module calls `f` multiple times, I’d prefer the constant to be computed and stored in the graph only once. Without compilation, I’d wrap `construct` in `functools.lru_cache` to memoize the result; however, when I compile `f` with the cache in place, the compiler tries to compile the caching function, which fails.
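For reference, the eager-mode memoization described above might look like the following sketch. It uses a hypothetical stand-in for `construct` that returns a tuple instead of a tensor, so it runs without PyTorch; in the real code the arguments would be `x.device` and `x.dtype`.

```python
import functools

@functools.lru_cache(maxsize=None)
def construct(device, dtype):
    # Hypothetical stand-in: the real function builds a tensor on `device`
    # with `dtype`. A tuple is returned here so the sketch runs without PyTorch.
    return ("constant", device, dtype)

# Repeated calls with the same (device, dtype) reuse the cached object.
first = construct("cpu", "float32")
second = construct("cpu", "float32")
print(first is second)  # True
```

The second call is a cache hit, so the expensive construction runs only once per distinct `(device, dtype)` pair.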
I think I’ve found a solution. We can write an `lru_cache`-like memoizer that looks like a plain dictionary to the compiler. The compiler then does not trace the construction function; instead, it treats every key as already present when compiling the function.
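One way to build such a memoizer is a `dict` subclass that fills in missing entries on first lookup via `__missing__`. The sketch below is an assumption about how the idea could be realized, not the exact code from this thread: the class name `ConstantCache` and the stand-in `construct` are hypothetical, and how `torch.compile` handles `__missing__` during tracing has not been verified here. A plausible workflow is to run `f` once eagerly so that every key the compiled function reads is already in the dictionary.

```python
class ConstantCache(dict):
    """Hypothetical dict subclass that builds missing entries on first lookup."""

    def __init__(self, factory):
        super().__init__()
        self._factory = factory

    def __missing__(self, key):
        # dict.__getitem__ calls this only when `key` is absent.
        # The key is a tuple of arguments, e.g. (device, dtype).
        value = self._factory(*key)
        self[key] = value
        return value


calls = []

def construct(device, dtype):
    # Hypothetical stand-in for the expensive tensor construction.
    calls.append((device, dtype))
    return ("constant", device, dtype)

constants = ConstantCache(construct)

a = constants["cpu", "float32"]  # miss: construct runs
b = constants["cpu", "float32"]  # hit: plain dictionary read
```

With PyTorch, `torch.device` and `torch.dtype` objects are hashable, so the original `f` could read `constant_tensor = constants[x.device, x.dtype]`. Note that `constants.get(key)` does not trigger `__missing__`, so lookups should use indexing.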