Global constants with TorchScript

OK, I know that global constants defined at the top level of a Python file don't work with torch.jit.script(), but I've tried four or five seemingly reasonable ways to define a collection of constants for TorchScript consumption and none of them work.

Does anyone have a good solution for project-wide constants used in a range of functions/modules across different files, short of passing around and indexing a dict? That seems 1) inefficient, since the constants can't be compiled down as constants, and 2) like bad code, since it's not obvious the dict shouldn't be mutated. I've considered just hard-coding the constants across the codebase, but some are used in 20+ places and could actually change in the future despite not being parameters (think physical device measurements), which would create a real maintenance headache.

Surely there's a best-practice way to make constants available without a lot of object passing? Any advice would be very much appreciated.
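
For context, the closest I've gotten is annotating module attributes with typing.Final, which torch.jit.script does treat as compile-time constants. A minimal sketch of that pattern (the constant name and value are made up for illustration):

import torch
from torch import nn
from typing import Final

# e.g. constants.py, importable from anywhere in the project
DEVICE_WIDTH_MM = 12.5  # illustrative physical measurement

class Scaler(nn.Module):
    # Final-annotated attributes are baked into the compiled
    # graph as constants by torch.jit.script, not kept as state.
    width_mm: Final[float]

    def __init__(self):
        super().__init__()
        self.width_mm = DEVICE_WIDTH_MM

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.width_mm

scripted = torch.jit.script(Scaler())

This works, but it still means threading the value into every module's __init__, which is exactly the object passing I'd like to avoid.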


It would indeed be really nice to have a good way to do this. I'm struggling with it now; it feels like I need to restructure my code rather significantly to make it work with torch.jit.script.

If, for example, I have a global (or some other kind of) constant float or tensor, could something like this exist:


import torch

x: torch.Tensor = torch.arange(3.0)  # some project-wide constant tensor
y: float = 2.0                       # some project-wide constant float

@torch.jit.script
def func(n: int) -> torch.Tensor:
    # torch.jit.assume_constant is imagined, not a real API
    x_const = torch.jit.assume_constant(x)
    return x_const + x_const[n] + torch.jit.assume_constant(y)

?
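
Nothing like assume_constant exists as far as I know, but a scripted getter function gets close for scalars, since scripted functions can call other scripted functions across files. A sketch under that assumption (values are illustrative):

import torch

@torch.jit.script
def get_y() -> float:
    # Compiled once; the literal is a constant in the graph.
    return 2.0

@torch.jit.script
def get_x() -> torch.Tensor:
    # Caveat: this rebuilds the tensor on every call.
    return torch.arange(3.0)

@torch.jit.script
def func(n: int) -> torch.Tensor:
    x = get_x()
    return x + x[n] + get_y()

For a tensor constant this is less attractive because of the rebuild-per-call cost; registering the tensor as a buffer (or Final attribute) on a module avoids that, at the price of structuring the code around modules.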
