diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py index a4228dc654ddca..87a8658672b83e 100644 --- a/Lib/test/test_sax.py +++ b/Lib/test/test_sax.py @@ -2,7 +2,8 @@ # $Id$ from xml.sax import make_parser, ContentHandler, \ - SAXException, SAXReaderNotAvailable, SAXParseException + SAXException, SAXReaderNotAvailable, SAXParseException, \ + saxutils try: make_parser() except SAXReaderNotAvailable: @@ -173,6 +174,21 @@ def test_parse_InputSource(self): input.setEncoding('iso-8859-1') self.check_parse(input) + def test_parse_close_source(self): + builtin_open = open + non_local = {'fileobj': None} + + def mock_open(*args): + fileobj = builtin_open(*args) + non_local['fileobj'] = fileobj + return fileobj + + with support.swap_attr(saxutils, 'open', mock_open): + make_xml_file(self.data, 'iso-8859-1', None) + with self.assertRaises(SAXException): + self.check_parse(TESTFN) + self.assertTrue(non_local['fileobj'].closed) + def check_parseString(self, s): from xml.sax import parseString result = StringIO() diff --git a/Lib/xml/sax/expatreader.py b/Lib/xml/sax/expatreader.py index 21c9db91e9d465..bae663bdd99124 100644 --- a/Lib/xml/sax/expatreader.py +++ b/Lib/xml/sax/expatreader.py @@ -105,9 +105,16 @@ def parse(self, source): source = saxutils.prepare_input_source(source) self._source = source - self.reset() - self._cont_handler.setDocumentLocator(ExpatLocator(self)) - xmlreader.IncrementalParser.parse(self, source) + try: + self.reset() + self._cont_handler.setDocumentLocator(ExpatLocator(self)) + xmlreader.IncrementalParser.parse(self, source) + except: + # bpo-30264: Close the source on error to not leak resources: + # xml.sax.parse() doesn't give access to the underlying parser + # to the caller + self._close_source() + raise def prepareParser(self, source): if source.getSystemId() is not None: @@ -216,6 +223,17 @@ def feed(self, data, isFinal = 0): # FIXME: when to invoke error()? self._err_handler.fatalError(exc) + def _close_source(self): + source = self._source + try: + file = source.getCharacterStream() + if file is not None: + file.close() + finally: + file = source.getByteStream() + if file is not None: + file.close() + def close(self): if (self._entity_stack or self._parser is None or isinstance(self._parser, _ClosedParser)): @@ -235,6 +253,7 @@ def close(self): parser.ErrorColumnNumber = self._parser.ErrorColumnNumber parser.ErrorLineNumber = self._parser.ErrorLineNumber self._parser = parser + self._close_source() def _reset_cont_handler(self): self._parser.ProcessingInstructionHandler = \