From 246ad20411874c0a416e8ee577ede498e15290e5 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Wed, 18 Oct 2017 17:30:13 +0300 Subject: [PATCH 1/3] bpo-31778: Make ast.literal_eval() more strict. Addition and subtraction of arbitrary numbers no longer allowed. --- Lib/ast.py | 36 +++++++++++-------- Lib/test/test_ast.py | 29 +++++++++++++-- .../2017-10-18-17-29-30.bpo-31778.B6vAkP.rst | 2 ++ 3 files changed, 49 insertions(+), 18 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst diff --git a/Lib/ast.py b/Lib/ast.py index 070c2bee7f9dee0..2ecb03f38bc0d00 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -35,8 +35,6 @@ def parse(source, filename='', mode='exec'): return compile(source, filename, mode, PyCF_ONLY_AST) -_NUM_TYPES = (int, float, complex) - def literal_eval(node_or_string): """ Safely evaluate an expression node or a string containing a Python @@ -48,6 +46,21 @@ def literal_eval(node_or_string): node_or_string = parse(node_or_string, mode='eval') if isinstance(node_or_string, Expression): node_or_string = node_or_string.body + def _convert_num(node): + if isinstance(node, Constant): + if isinstance(node.value, (int, float, complex)): + return node.value + elif isinstance(node, Num): + return node.n + raise ValueError('malformed node or string: ' + repr(node)) + def _convert_signed_num(node): + if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)): + operand = _convert_num(node.operand) + if isinstance(node.op, UAdd): + return + operand + else: + return - operand + return _convert_num(node) def _convert(node): if isinstance(node, Constant): return node.value @@ -62,26 +75,19 @@ def _convert(node): elif isinstance(node, Set): return set(map(_convert, node.elts)) elif isinstance(node, Dict): - return dict((_convert(k), _convert(v)) for k, v - in zip(node.keys, node.values)) + return dict(zip(map(_convert, node.keys), + map(_convert, node.values))) elif isinstance(node, NameConstant): return node.value - elif isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)): - operand = _convert(node.operand) - if isinstance(operand, _NUM_TYPES): - if isinstance(node.op, UAdd): - return + operand - else: - return - operand elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)): - left = _convert(node.left) - right = _convert(node.right) - if isinstance(left, _NUM_TYPES) and isinstance(right, _NUM_TYPES): + left = _convert_signed_num(node.left) + right = _convert_num(node.right) + if isinstance(left, (int, float)) and isinstance(right, complex): if isinstance(node.op, Add): return left + right else: return left - right - raise ValueError('malformed node or string: ' + repr(node)) + return _convert_signed_num(node) return _convert(node_or_string) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index aa53503e3b5d8a4..f8b86f92f943bf3 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -551,9 +551,32 @@ def test_literal_eval(self): self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3}) self.assertEqual(ast.literal_eval('b"hi"'), b"hi") self.assertRaises(ValueError, ast.literal_eval, 'foo()') + self.assertEqual(ast.literal_eval('6'), 6) + self.assertEqual(ast.literal_eval('+6'), 6) self.assertEqual(ast.literal_eval('-6'), -6) - self.assertEqual(ast.literal_eval('-6j+3'), 3-6j) self.assertEqual(ast.literal_eval('3.25'), 3.25) + self.assertEqual(ast.literal_eval('+3.25'), 3.25) + self.assertEqual(ast.literal_eval('-3.25'), -3.25) + self.assertEqual(ast.literal_eval('6j'), 6j) + self.assertEqual(ast.literal_eval('-6j'), -6j) + self.assertEqual(ast.literal_eval('6.75j'), 6.75j) + self.assertEqual(ast.literal_eval('-6.75j'), -6.75j) + self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0') + self.assertEqual(ast.literal_eval('3+6j'), 3+6j) + self.assertEqual(ast.literal_eval('-3+6j'), -3+6j) + self.assertEqual(ast.literal_eval('3-6j'), 3-6j) + self.assertEqual(ast.literal_eval('-3-6j'), -3-6j) + self.assertEqual(ast.literal_eval('3.25+6.75j'), 3.25+6.75j) + self.assertEqual(ast.literal_eval('-3.25+6.75j'), -3.25+6.75j) + self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j) + self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j) + self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j) + self.assertRaises(ValueError, ast.literal_eval, '++6') + self.assertRaises(ValueError, ast.literal_eval, '+True') + self.assertRaises(ValueError, ast.literal_eval, '2+3') + self.assertRaises(ValueError, ast.literal_eval, '-6j+3') + self.assertRaises(ValueError, ast.literal_eval, '3+-6j') + self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)') def test_literal_eval_issue4907(self): self.assertEqual(ast.literal_eval('2j'), 2j) @@ -1077,11 +1100,11 @@ def test_literal_eval(self): ast.copy_location(new_left, binop.left) binop.left = new_left - new_right = ast.Constant(value=20) + new_right = ast.Constant(value=20j) ast.copy_location(new_right, binop.right) binop.right = new_right - self.assertEqual(ast.literal_eval(binop), 30) + self.assertEqual(ast.literal_eval(binop), 10+20j) def main(): diff --git a/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst b/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst new file mode 100644 index 000000000000000..452ad6e4bd2afb0 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-10-18-17-29-30.bpo-31778.B6vAkP.rst @@ -0,0 +1,2 @@ +ast.literal_eval() is now more strict. Addition and subtraction of +arbitrary numbers no longer allowed. From 1a0cd38fa5b8527a467c96f4c22a7343bcf4e28d Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 22 Oct 2017 13:30:02 +0300 Subject: [PATCH 2/3] Fix test_inspect. --- Lib/test/test_inspect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 819fcc585376aad..1496585bea5b232 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -2053,7 +2053,7 @@ def p(name): return signature.parameters[name].default self.assertEqual(p('f'), False) self.assertEqual(p('local'), 3) self.assertEqual(p('sys'), sys.maxsize) - self.assertEqual(p('exp'), sys.maxsize - 1) + self.assertNotIn('exp', signature.parameters) test_callable(object) From da48ee4ca707e431adb4848fc29381c0991d7740 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 9 Nov 2017 23:39:18 +0200 Subject: [PATCH 3/3] Reorganize tests. --- Lib/test/test_ast.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index f8b86f92f943bf3..67f363ad31f39ab 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -557,11 +557,17 @@ def test_literal_eval(self): self.assertEqual(ast.literal_eval('3.25'), 3.25) self.assertEqual(ast.literal_eval('+3.25'), 3.25) self.assertEqual(ast.literal_eval('-3.25'), -3.25) + self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0') + self.assertRaises(ValueError, ast.literal_eval, '++6') + self.assertRaises(ValueError, ast.literal_eval, '+True') + self.assertRaises(ValueError, ast.literal_eval, '2+3') + + def test_literal_eval_complex(self): + # Issue #4907 self.assertEqual(ast.literal_eval('6j'), 6j) self.assertEqual(ast.literal_eval('-6j'), -6j) self.assertEqual(ast.literal_eval('6.75j'), 6.75j) self.assertEqual(ast.literal_eval('-6.75j'), -6.75j) - self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0') self.assertEqual(ast.literal_eval('3+6j'), 3+6j) self.assertEqual(ast.literal_eval('-3+6j'), -3+6j) self.assertEqual(ast.literal_eval('3-6j'), 3-6j) @@ -571,18 +577,12 @@ def test_literal_eval(self): self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j) self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j) self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j) - self.assertRaises(ValueError, ast.literal_eval, '++6') - self.assertRaises(ValueError, ast.literal_eval, '+True') - self.assertRaises(ValueError, ast.literal_eval, '2+3') self.assertRaises(ValueError, ast.literal_eval, '-6j+3') + self.assertRaises(ValueError, ast.literal_eval, '-6j+3j') self.assertRaises(ValueError, ast.literal_eval, '3+-6j') + self.assertRaises(ValueError, ast.literal_eval, '3+(0+6j)') self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)') - def test_literal_eval_issue4907(self): - self.assertEqual(ast.literal_eval('2j'), 2j) - self.assertEqual(ast.literal_eval('10 + 2j'), 10 + 2j) - self.assertEqual(ast.literal_eval('1.5 - 2j'), 1.5 - 2j) - def test_bad_integer(self): # issue13436: Bad error message with invalid numeric values body = [ast.ImportFrom(module='time',