diff --git a/CHANGES.rst b/CHANGES.rst index 537c5a8..d0609a4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,7 +4,7 @@ Changes 8.3 (unreleased) ---------------- -- Nothing changed yet. +- Allow ``ast.Module``, ``ast.Expression`` and ``ast.Interactive`` as body in compile_restricted_function 8.3a1.dev0 (2026-05-29) diff --git a/docs/usage/api.rst b/docs/usage/api.rst index 8e8ca90..59eaddd 100644 --- a/docs/usage/api.rst +++ b/docs/usage/api.rst @@ -67,10 +67,10 @@ API overview :param flags: (optional). defaults to ``0`` :param dont_inherit: (optional). defaults to ``False`` :param policy: (optional). defaults to ``RestrictingNodeTransformer`` - :type p: str or unicode text - :type body: str or unicode text - :type name: str or unicode text - :type filename: str or unicode text + :type p: str + :type body: str or bytes or bytearray or ``ast.Module`` or ``ast.Expression`` or ``ast.Interactive`` + :type name: str + :type filename: str or bytes or os.PathLike[typing.Any] :type globalize: None or list :type flags: int :type dont_inherit: int diff --git a/src/RestrictedPython/compile.py b/src/RestrictedPython/compile.py index 3253b8c..2d0513a 100644 --- a/src/RestrictedPython/compile.py +++ b/src/RestrictedPython/compile.py @@ -4,6 +4,7 @@ from RestrictedPython._compat import IS_CPYTHON from RestrictedPython.transformer import RestrictingNodeTransformer +from RestrictedPython.transformer import copy_locations CompileResult = namedtuple( @@ -140,16 +141,23 @@ def compile_restricted_function( http://restrictedpython.readthedocs.io/en/latest/usage/index.html#RestrictedPython.compile_restricted_function """ # Parse the parameters and body, then combine them. - try: - body_ast = ast.parse(body, '', 'exec') - except SyntaxError as v: - error = syntax_error_template.format( - lineno=v.lineno, - type=v.__class__.__name__, - msg=v.msg, - statement=v.text.strip() if v.text else None) - return CompileResult( - code=None, errors=(error,), warnings=(), used_names=()) + if isinstance(body, ast.Expression): + _body_ast = ast.Expr(body.body) + copy_locations(_body_ast, body.body) + body_ast = [_body_ast] + elif isinstance(body, (ast.Module, ast.Interactive)): + body_ast = body.body + else: + try: + body_ast = ast.parse(body, '', 'exec').body + except SyntaxError as v: + error = syntax_error_template.format( + lineno=v.lineno, + type=v.__class__.__name__, + msg=v.msg, + statement=v.text.strip() if v.text else None) + return CompileResult( + code=None, errors=(error,), warnings=(), used_names=()) # The compiled code is actually executed inside a function # (that is called when the code is called) so reading and assigning to a @@ -157,7 +165,7 @@ def compile_restricted_function( # UnboundLocalError. # We don't want the user to need to understand this. if globalize: - body_ast.body.insert(0, ast.Global(globalize)) + body_ast.insert(0, ast.Global(globalize)) wrapper_ast = ast.parse('def masked_function_name(%s): pass' % p, '', 'exec') # In case the name you chose for your generated function is not a @@ -166,7 +174,7 @@ def compile_restricted_function( assert isinstance(function_ast, ast.FunctionDef) function_ast.name = name - wrapper_ast.body[0].body = body_ast.body + wrapper_ast.body[0].body = body_ast wrapper_ast = ast.fix_missing_locations(wrapper_ast) result = _compile_restricted_mode( diff --git a/tests/test_compile_restricted_function.py b/tests/test_compile_restricted_function.py index d1454db..b282ad4 100644 --- a/tests/test_compile_restricted_function.py +++ b/tests/test_compile_restricted_function.py @@ -1,3 +1,4 @@ +import ast from types import FunctionType from RestrictedPython import PrintCollector @@ -233,3 +234,68 @@ def test_compile_restricted_function_invalid_syntax(): assert error_msg.startswith( "Line 1: SyntaxError: cannot assign to literal here. Maybe " ) + + +def test_compile_restricted_function_pre_parse_exec(): + p = '' + body = ast.parse(""" +print("Hello World!") +return printed +""") + name = "hello_world" + global_symbols = [] + + result = compile_restricted_function( + p, # parameters + body, + name, + filename='', + globalize=global_symbols + ) + + assert result.code is not None + assert result.errors == () + + safe_globals = { + '__name__': 'script', + '_getattr_': getattr, + '_print_': PrintCollector, + '__builtins__': safe_builtins, + } + safe_locals = {} + exec(result.code, safe_globals, safe_locals) + hello_world = safe_locals['hello_world'] + assert type(hello_world) is FunctionType + assert hello_world() == 'Hello World!\n' + + +def test_compile_restricted_function_pre_parse_single(): + p = '' + body = ast.parse(""" +return "Hello World!" +""", mode="single") + name = "hello_world" + global_symbols = [] + + result = compile_restricted_function( + p, # parameters + body, + name, + filename='', + globalize=global_symbols + ) + + assert result.code is not None + assert result.errors == () + + safe_globals = { + '__name__': 'script', + '_getattr_': getattr, + '_print_': PrintCollector, + '__builtins__': safe_builtins, + } + safe_locals = {} + exec(result.code, safe_globals, safe_locals) + hello_world = safe_locals['hello_world'] + assert type(hello_world) is FunctionType + assert hello_world() == 'Hello World!'