Any interest in a function for scripting strings?

Introduction

Hello everyone. Long time lurker, first time poster, etc.

I have developed a function for a torchscript preprocessing application I am
building, and I think it might be of interest to the broader community. I would
like to put out some feelers to see whether you are interested.

I have made a function that, bluntly, lets you execute code from a string in such a manner that you can later script it. It is something of a hack: it works by dumping the string into a temporary file, which then acts as a temporary module, letting inspect find the source when jit goes looking for it. If you script the resulting objects while the temporary module is around, they end up in jit’s cache and can be used elsewhere.
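The reason the temporary file matters can be seen with the standard library alone. The sketch below is illustrative, not the code from this post: the helper names `exec_plain` and `exec_file_backed` are made up, and it skips the `sys.modules` registration the post's version performs. It compiles the same string twice, once under the placeholder filename `<string>` and once under the path of a temporary `.py` file. In the first case `inspect.getsource` typically raises `OSError` because there is no file to read; in the second it reads the source back from disk, which is the kind of lookup torchscript performs when it goes looking for source.

```python
import inspect
import os
import tempfile

SOURCE = "def foo(x: int) -> int:\n    return x + 3\n"


def exec_plain(source: str, name: str):
    """exec() the string directly; the code object only knows the filename '<string>'."""
    namespace = {}
    exec(compile(source, "<string>", "exec"), namespace)
    return namespace[name]


def exec_file_backed(source: str, name: str):
    """Write the string to a temporary .py file and compile it under that filename,
    so inspect can read the source back from disk.
    Returns (object, path); the caller is responsible for deleting the file."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as handle:
        handle.write(source)
        path = handle.name
    namespace = {}
    exec(compile(source, path, "exec"), namespace)
    return namespace[name], path


plain = exec_plain(SOURCE, "foo")
try:
    inspect.getsource(plain)
except OSError as err:
    # Typical outcome: there is no file to read the source from.
    print("plain exec:", err)

backed, path = exec_file_backed(SOURCE, "foo")
try:
    print(inspect.getsource(backed))
finally:
    os.remove(path)  # the source is only recoverable while the file exists
```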

Example

As an example, let’s say you define a function “fo” normally and, in a string, a function “foo” that calls it:

def fo()->int:
    return 3

func_source="""\
def foo(item: torch.Tensor)->torch.Tensor:
    return item + fo()
"""

The function I have written lets you execute that string to get a live
function that is scriptable and usable as normal:

live_func = exec_scriptable_string(func_source, "foo")
print(live_func(torch.Tensor([3])))  # prints tensor([6.])
scripted_func = torch.jit.script(live_func)
print(scripted_func(torch.Tensor([4])))  # prints tensor([7.])

Note that this is not restricted to functions; it should also work for things like classes.

My Questions

Hack or not, I will note that torch currently lacks this functionality, and I do not believe it would be very hard to integrate into the current codebase. So I have two questions:

  1. Is the community interested in adding this capability to torchscript?
  2. If yes, how do I make a pull request to get it into the main branch? I have not submitted a pull request before, and I also want to make sure I update the documentation to include the new function.

Thank you for your time

That’s an interesting idea, but what’s the use case for it?
Given that your method is parsed from a string your IDE will most likely ignore it and not highlight any syntax errors etc.
Your approach generally reminds me of Python’s eval() function, which I would personally not want to use on 3rd-party code, as it’s a security risk and hard to debug.

My particular use case is that I wanted to be able to use inheritance with non-module classes, so I made a small program that does static analysis of a class: it follows the MRO chain and, using AST analysis, synthesizes a new “flattened” class with all the methods available to the original. This code is then executed to produce a live class with all the methods available on the original. As far as my tests are concerned, it currently works.
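The post does not include the flattening program itself, so the following is only a hypothetical reconstruction of the idea, assuming Python 3.9+ (for `ast.unparse`) and assuming every class in the chain is defined in one source string. `flatten_class` is an invented name. It lets Python compute the MRO by executing the source once, walks that MRO over the parsed `ast.ClassDef` nodes keeping the first definition of each method name (mirroring normal attribute lookup), and unparses a new class with no bases. It deliberately ignores class attributes, nested classes and `super()` calls, which would need extra handling.

```python
import ast
import copy


def flatten_class(module_source: str, class_name: str, flat_name: str) -> str:
    """Hypothetical sketch: return source for a base-less class `flat_name`
    carrying every method visible on `class_name` in `module_source`."""
    tree = ast.parse(module_source)
    class_defs = {node.name: node for node in tree.body
                  if isinstance(node, ast.ClassDef)}

    # Execute the source once so Python computes the MRO for us.
    # Note: this runs any module-level code in module_source.
    namespace = {}
    exec(compile(tree, "<flatten>", "exec"), namespace)
    mro = namespace[class_name].__mro__

    methods = {}
    for klass in mro:
        # Matched by bare name, so this assumes class names are unique in the source.
        node = class_defs.get(klass.__name__)
        if node is None:  # e.g. `object`, or a base not defined in module_source
            continue
        for item in node.body:
            # The first definition met along the MRO wins, as in attribute lookup.
            if isinstance(item, ast.FunctionDef) and item.name not in methods:
                methods[item.name] = copy.deepcopy(item)

    # Reuse the original ClassDef node (keeps any version-specific fields),
    # but drop bases, keywords and decorators and swap in the collected methods.
    flat = copy.deepcopy(class_defs[class_name])
    flat.name = flat_name
    flat.bases, flat.keywords, flat.decorator_list = [], [], []
    flat.body = list(methods.values()) or [ast.Pass()]
    return ast.unparse(ast.fix_missing_locations(flat))


SOURCE = '''
class Base:
    def greet(self) -> str:
        return "base"

    def value(self) -> int:
        return 1

class Child(Base):
    def value(self) -> int:
        return 2
'''

flat_source = flatten_class(SOURCE, "Child", "FlatChild")
print(flat_source)
```

This also illustrates why isinstance checks against the original hierarchy fail after the conversion: the synthesized class has no bases at all.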

Frankly, if I knew where to put it, I would have preferred to integrate this directly into torchscript, and I might still be able to. However, I am not going to do so unless there is significant interest, and I would likely still need help knowing where to hook it up. At the end of the process I do end up with an AST node describing the class, and I have seen such nodes somewhere in torchscript’s source before, so I do not think it would be too difficult?

Since my use case can be applied as a decorator, my IDE is still fairly happy. Notably, however, since I am not doing anything deep to torch’s type system, isinstance does not work properly after this conversion.

Another use case I have seen at least once on the forums: the person who implemented dataclass support in the main branch was bemoaning not being able to execute arbitrary strings and having to work around that limitation. I think they got around it by defining every single method manually?

And yes, at the moment it is designed to mimic eval, but to return an object that is scriptable, so that it behaves in a familiar manner. Behind that, however, is a context manager that writes the string into a temporary file, loads that file as a module, and returns the module as the value of the with statement.
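A stripped-down version of such a context manager can be written with `contextlib.contextmanager`. This is a simplified sketch, not the post's `StringScriptContext`: `temporary_module` is a hypothetical name, it executes the string in a fresh module namespace rather than one seeded from the caller's globals and locals, and it does not check for module-name collisions. It shows the lifecycle: write a temporary `.py` file, build a module with `importlib.util.spec_from_file_location` and `importlib.util.module_from_spec`, register it in `sys.modules` while the block is open, then unregister it and delete the file on exit.

```python
import contextlib
import importlib.util
import inspect
import os
import sys
import tempfile
from types import ModuleType
from typing import Iterator


@contextlib.contextmanager
def temporary_module(source: str) -> Iterator[ModuleType]:
    """Simplified sketch: back `source` with a temporary .py file and expose it as
    a real module (registered in sys.modules) for the duration of the with-block."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as handle:
        handle.write(source)
        path = handle.name
    name = os.path.splitext(os.path.basename(path))[0]  # e.g. "tmpab12cd34"
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    try:
        # Execute in the module's own namespace, so __name__ is the module name
        # and inspect can map functions and classes back to this file.
        exec(compile(source, path, "exec"), module.__dict__)
        yield module
    finally:
        sys.modules.pop(name, None)
        os.remove(path)


with temporary_module("def triple(x: int) -> int:\n    return 3 * x\n") as mod:
    print(mod.triple(2))
    print(inspect.getsource(mod.triple))
```

Executing in `module.__dict__` gives functions and classes a `__module__` that matches the registered module, so `inspect` can locate classes in the temporary file as well as functions.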

Thanks for the detailed explanation!
Since TorchScript is in maintenance mode, I would not expect to see any new features being implemented, but maybe your approach could be valuable in the new Dynamo stack.
Feel free to create a feature request with your explanation (or a link to your post) on GitHub so that the code owners could take a look at it.

Interesting, I had not heard of Dynamo before this. I will have to poke around in it. Does it produce something that ONNX can interact with properly? What exactly is it supposed to do? I will look into including this in Dynamo. Thanks for your help.

Also, I will go ahead and leave the code I have developed here in case anyone has this problem in the future. It is fairly well documented, but anyone who has questions may feel free to ask whatever they want.

"""

Torch does not normally like to script text
strings, which is a pity. This file helps to
fix that.

"""
import importlib
import importlib.util
import io
import os
import pathlib
import sys
import tempfile
import inspect
from typing import Optional, Dict, Any, Union, List, Type, Callable, Tuple

import torch.jit
import torch

class StringScriptContext():
    """
    Introduction
    ############
    The class is a context manager which may be initialized with
    the source string to execute, along with environmental parameters
    such as local and global details. It can then be used to
    torchscript compile such strings as though they were defined in
    a module.
    It has several modes. One can independently compile
    and retrieve a single entity using the get method
    after initialization. Alternatively, one can open a context
    using a "with" statement, which will give the user full
    access to everything declared within the string. The user
    can then decide themselves what to compile and when.
    Sharp bits
    ##########
    Do note that the environmental compiler can cause side
    effects if you are not careful. While execution will not
    alter the locals and globals in your root environment,
    it MAY alter things within them.
    For example, this will produce a side effect:
    ```
    list_for_effect = []
    source = \"\"\"
    def cause_side_effect():
        list_for_effect.append(5)
    cause_side_effect()
    \"\"\"
    with StringScriptContext(source, globals(), locals()) as module:
        pass
    print(list_for_effect[0]) # Prints 5.
    ```
    Be aware thus that although the environment will not
    be directly modified by code in the string, the objects
    within it CAN be.
    Examples
    ########
    **Example: Scripting a single entity and retrieving the result**
    ```
    .. testcode::
        source = \"\"\"
        import torch
        def add_five(input: torch.Tensor)->torch.Tensor:
            return input + 5
        \"\"\"
        scripted_add_five = StringScriptContext(source).get("add_five")
        test_tensor = torch.tensor(0)
        print(scripted_add_five(test_tensor)) #Returns tensor(5)
    ```
    **Example: Scripting by opening a context manager**
    ```
    .. testcode::
        source = \"\"\"
        import torch
        def add_five(input: torch.Tensor)->torch.Tensor:
            return input + 5
        \"\"\"
        add_five = None
        with StringScriptContext(source) as module:
            add_five = torch.jit.script(module.add_five)
    ```
    **Example: Scripting with environmental dependencies**
    ```
    .. testcode::
        def redirect():
            return 4
        source = \"\"\"\
        def get_4():
            return redirect()
        \"\"\"
        with StringScriptContext(source, globals(), locals()) as module:
            get = torch.jit.script(module.get_4)
            print(get()) #Returns 4
    ```
    """
    #Debug info
    # Important things to note for maintainers and troubleshooters:
    #   This function edits sys.modules. It will remove the edit later, but be careful here
    #   This func

    def exec_module(self, module):
        """
        NOT USER SERVICEABLE
        Compiles and executes the
        source code underlying the given module.
        Modules entering this method are freshly
        created, empty module objects. This
        method compiles and executes the code
        from source, then transfers the novel
        results onto the module.
        """

        # Execute in one namespace seeded from the caller's globals and locals.
        # With separate globals/locals dicts, exec() behaves like a class body,
        # and functions defined in the string could not see one another.
        base_namespace = {**self.globals, **self.locals}
        namespace = dict(base_namespace)

        with open(module.__file__) as source_file:
            source = source_file.read()
        code = compile(source, filename=module.__file__, mode="exec")
        exec(code, namespace)

        # Transfer only names the string created or rebound. Comparing by key
        # and identity avoids `value in dict.values()`, which uses == and can
        # misfire (e.g. equal-but-distinct lists, or tensors).
        for key, value in namespace.items():
            if key not in base_namespace or base_namespace[key] is not value:
                setattr(module, key, value)
        return module

    def get_handle(self)-> io.TextIOWrapper:
        """
        Gets a temporary file handle.
        Tries again if the suggested name has a collision in
        the system module attribute.
        """
        handle = None
        while handle is None:
            #Fetch a collision free module name
            _handle = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
            path = _handle.name
            name = pathlib.Path(path).stem
            if name not in sys.modules:
                handle = _handle
            else:
                # delete=False means close() leaves the file behind; remove it.
                _handle.close()
                os.remove(path)
        return handle

    def get(self, name: str)->Union[torch.jit.ScriptModule, torch.jit.ScriptFunction]:
        """
        Get a particular compiled feature
        out of the source string. Accepts the name
        it was declared as in the original string.
        :param name: The name of the thing to compile
        :return: A script feature
        :raises: AssertionError:
        """
        output = None
        assert isinstance(name, str), "Name must be string"
        with self as module:
            output = getattr(module, name)
            output.__module__ = self.__module__ #Small hack to ensure qualified names resolve.
            output = torch.jit.script(output)
        return output

    def __init__(self,
                 source: str,
                 globals: Optional[Dict[str, Any]] = None,
                 locals: Optional[Dict[str, Any]] = None,
                 retain_file_on_error: bool = False,
                 ):
        """
        :param source: A string of source code to execute.
        :param globals: The globals in the environment. May be none, yielding a blank environment
        :param locals: The locals in the environment. May be none, yielding a blank environment.
        :param retain_file_on_error: Whether or not to retain a file so that the stack trace is easy
                to follow

        """

        if globals is None:
            globals = {}
        if locals is None:
            locals = {}

        self.globals = globals
        self.locals = locals
        self.source = source
        self.retain_file = retain_file_on_error

        self.path = None
        self.name = None
        self.module = None

    def __enter__(self):
        """
        Dumps the source string into a temporary file,
        creates an associated module in a global safe
        environment, then returns the new module
        :return: An active module with any generated features.
        """
        #This is the key function of the class, and
        #as such deserves some extra commentary
        #
        #The basic process for handling a source string
        #is to create a temporary file, dump everything
        #into it, create a temporary module backed by the file,
        #and execute everything, transferring it onto the module
        #
        #



        #Write to the temporary, then close it. It will not delete.
        handle = self.get_handle()
        path = handle.name
        handle.write(self.source)
        handle.close()

        #Form module
        name = pathlib.Path(path).stem
        spec = importlib.util.spec_from_file_location(name, handle.name)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module.__name__] = module #This line is required for inspect to work

        try:
            self.exec_module(module)
        except Exception:
            # Undo the sys.modules registration so a failed load does not leak.
            sys.modules.pop(name, None)
            if self.retain_file is False:
                os.remove(path)
            raise

        self.path = path
        self.name = name
        self.module = module
        return module
    def __exit__(self, exc_type, exc_val, exc_tb):
        #TODO: Add some decent error handling.
        sys.modules.pop(self.name)
        if exc_val is None or self.retain_file is False:
            os.remove(self.path)
        self.path = None
        self.name = None

def get_context(frames_up: int) -> Tuple[dict, dict]:
    """
    Returns the global and local context at the specified number of frames up the stack.

    :param frames_up: The number of frames above the caller of get_context to look
                      (0 returns the context of the function that called get_context).
    :return: A tuple of the global and local context at the specified number of frames up the stack.
    """
    frame = inspect.currentframe()  # get_context's own frame
    for _ in range(frames_up + 1):
        frame = frame.f_back
    return frame.f_globals, frame.f_locals

def exec_scriptable_string(source: str,
                           feature: str)->Union[Type, Callable]:
    """
    A scripting agent designed to execute a source string
    containing scriptable code and retrieve a particular
    object out of the source. The returned object is an unscripted instance,
    but will operate properly when torchscript goes to script it in
    the future.

    Note that the targeted feature must be scriptable if typed in
    normally for this function to work.

    :param source: The source code to work with
    :param feature: The name of the class or function to retrieve
    :return: The function or class, in its unscripted format.
    """
    globals, locals = get_context(1)  # frames_up=1: the caller of exec_scriptable_string
    with StringScriptContext(source, globals, locals) as module:
        item = getattr(module, feature)
        torch.jit.script(item)
    return item