Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changes
8.3 (unreleased)
----------------

- Nothing changed yet.
- Allow ``ast.Module``, ``ast.Expression`` and ``ast.Interactive`` as body in compile_restricted_function


8.3a1.dev0 (2026-05-29)
Expand Down
8 changes: 4 additions & 4 deletions docs/usage/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ API overview
:param flags: (optional). defaults to ``0``
:param dont_inherit: (optional). defaults to ``False``
:param policy: (optional). defaults to ``RestrictingNodeTransformer``
:type p: str or unicode text
:type body: str or unicode text
:type name: str or unicode text
:type filename: str or unicode text
:type p: str
:type body: str or bytes or bytearray or ``ast.Module`` or ``ast.Expression`` or ``ast.Interactive``
:type name: str
:type filename: str or bytes or os.PathLike[typing.Any]
:type globalize: None or list
:type flags: int
:type dont_inherit: int
Expand Down
32 changes: 20 additions & 12 deletions src/RestrictedPython/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from RestrictedPython._compat import IS_CPYTHON
from RestrictedPython.transformer import RestrictingNodeTransformer
from RestrictedPython.transformer import copy_locations


CompileResult = namedtuple(
Expand Down Expand Up @@ -140,24 +141,31 @@ def compile_restricted_function(
http://restrictedpython.readthedocs.io/en/latest/usage/index.html#RestrictedPython.compile_restricted_function
"""
# Parse the parameters and body, then combine them.
try:
body_ast = ast.parse(body, '<func code>', 'exec')
except SyntaxError as v:
error = syntax_error_template.format(
lineno=v.lineno,
type=v.__class__.__name__,
msg=v.msg,
statement=v.text.strip() if v.text else None)
return CompileResult(
code=None, errors=(error,), warnings=(), used_names=())
if isinstance(body, ast.Expression):
_body_ast = ast.Expr(body.body)
copy_locations(_body_ast, body.body)
body_ast = [_body_ast]
elif isinstance(body, (ast.Module, ast.Interactive)):
body_ast = body.body
else:
try:
body_ast = ast.parse(body, '<func code>', 'exec').body
except SyntaxError as v:
error = syntax_error_template.format(
lineno=v.lineno,
type=v.__class__.__name__,
msg=v.msg,
statement=v.text.strip() if v.text else None)
return CompileResult(
code=None, errors=(error,), warnings=(), used_names=())

# The compiled code is actually executed inside a function
# (that is called when the code is called) so reading and assigning to a
# global variable like this`printed += 'foo'` would throw an
# UnboundLocalError.
# We don't want the user to need to understand this.
if globalize:
body_ast.body.insert(0, ast.Global(globalize))
body_ast.insert(0, ast.Global(globalize))
wrapper_ast = ast.parse('def masked_function_name(%s): pass' % p,
'<func wrapper>', 'exec')
# In case the name you chose for your generated function is not a
Expand All @@ -166,7 +174,7 @@ def compile_restricted_function(
assert isinstance(function_ast, ast.FunctionDef)
function_ast.name = name

wrapper_ast.body[0].body = body_ast.body
wrapper_ast.body[0].body = body_ast
wrapper_ast = ast.fix_missing_locations(wrapper_ast)

result = _compile_restricted_mode(
Expand Down
66 changes: 66 additions & 0 deletions tests/test_compile_restricted_function.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from types import FunctionType

from RestrictedPython import PrintCollector
Expand Down Expand Up @@ -233,3 +234,68 @@ def test_compile_restricted_function_invalid_syntax():
assert error_msg.startswith(
"Line 1: SyntaxError: cannot assign to literal here. Maybe "
)


def test_compile_restricted_function_pre_parse_exec():
p = ''
body = ast.parse("""
print("Hello World!")
return printed
""")
name = "hello_world"
global_symbols = []

result = compile_restricted_function(
p, # parameters
body,
name,
filename='<string>',
globalize=global_symbols
)

assert result.code is not None
assert result.errors == ()

safe_globals = {
'__name__': 'script',
'_getattr_': getattr,
'_print_': PrintCollector,
'__builtins__': safe_builtins,
}
safe_locals = {}
exec(result.code, safe_globals, safe_locals)
hello_world = safe_locals['hello_world']
assert type(hello_world) is FunctionType
assert hello_world() == 'Hello World!\n'


def test_compile_restricted_function_pre_parse_single():
p = ''
body = ast.parse("""
return "Hello World!"
""", mode="single")
name = "hello_world"
global_symbols = []

result = compile_restricted_function(
p, # parameters
body,
name,
filename='<string>',
globalize=global_symbols
)

assert result.code is not None
assert result.errors == ()

safe_globals = {
'__name__': 'script',
'_getattr_': getattr,
'_print_': PrintCollector,
'__builtins__': safe_builtins,
}
safe_locals = {}
exec(result.code, safe_globals, safe_locals)
hello_world = safe_locals['hello_world']
assert type(hello_world) is FunctionType
assert hello_world() == 'Hello World!'
Loading