Skip to content

Commit a953365

Browse files
committed
bpo-33062: Added SSL renegotiate and key_update
1 parent 17775ae commit a953365

6 files changed

Lines changed: 530 additions & 11 deletions

File tree

Lib/ssl.py

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@
149149
lambda name: name.startswith('CERT_'),
150150
source=_ssl)
151151

152+
_IntEnum._convert_(
153+
'KeyUpdateTypes', __name__,
154+
lambda name: name.startswith('KEY_UPDATE_'),
155+
source=_ssl)
156+
152157
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
153158
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
154159

@@ -780,6 +785,23 @@ def version(self):
780785
def verify_client_post_handshake(self):
781786
return self._sslobj.verify_client_post_handshake()
782787

788+
def key_update(self, updatetype):
789+
self._sslobj.key_update(updatetype)
790+
791+
@property
792+
def key_update_type(self):
793+
return KeyUpdateTypes(self._sslobj.get_key_update_type())
794+
795+
def renegotiate(self, abbreviated=False):
796+
if abbreviated:
797+
self._sslobj.renegotiate_abbreviated()
798+
else:
799+
self._sslobj.renegotiate()
800+
801+
@property
802+
def renegotiate_pending(self):
803+
return self._sslobj.renegotiate_pending()
804+
783805

784806
class SSLSocket(socket):
785807
"""This class implements a subtype of socket.socket that wraps
@@ -1090,18 +1112,14 @@ def shutdown(self, how):
10901112
super().shutdown(how)
10911113

10921114
def unwrap(self):
1093-
if self._sslobj:
1094-
s = self._sslobj.shutdown()
1095-
self._sslobj = None
1096-
return s
1097-
else:
1098-
raise ValueError("No SSL wrapper around " + str(self))
1115+
self._ensure_wrapper()
1116+
s = self._sslobj.shutdown()
1117+
self._sslobj = None
1118+
return s
10991119

11001120
def verify_client_post_handshake(self):
1101-
if self._sslobj:
1102-
return self._sslobj.verify_client_post_handshake()
1103-
else:
1104-
raise ValueError("No SSL wrapper around " + str(self))
1121+
self._ensure_wrapper()
1122+
return self._sslobj.verify_client_post_handshake()
11051123

11061124
def _real_close(self):
11071125
self._sslobj = None
@@ -1190,6 +1208,76 @@ def version(self):
11901208
else:
11911209
return None
11921210

1211+
def _ensure_wrapper(self):
1212+
if not self._sslobj:
1213+
raise ValueError("No SSL wrapper around " + str(self))
1214+
1215+
def key_update(self, updatetype):
1216+
self._ensure_wrapper()
1217+
self._sslobj.key_update(updatetype)
1218+
1219+
@property
1220+
def key_update_type(self):
1221+
self._ensure_wrapper()
1222+
return KeyUpdateTypes(self._sslobj.get_key_update_type())
1223+
1224+
def renegotiate(self, abbreviated=False):
1225+
self._ensure_wrapper()
1226+
if abbreviated:
1227+
self._sslobj.renegotiate_abbreviated()
1228+
else:
1229+
self._sslobj.renegotiate()
1230+
1231+
@property
1232+
def renegotiate_pending(self):
1233+
self._ensure_wrapper()
1234+
return self._sslobj.renegotiate_pending()
1235+
1236+
1237+
for name, docstr in (
1238+
('key_update', """\
1239+
Schedule an update of the keys for the current TLS connection.
1240+
1241+
If the updatetype parameter is set to KEY_UPDATE_NOT_REQUESTED then the
1242+
sending keys for this connection will be updated and the peer will be
1243+
informed of the change. If the updatetype parameter is set to
1244+
KEY_UPDATE_REQUESTED then the sending keys for this connection will be
1245+
updated and the peer will be informed of the change along with a
1246+
request for the peer to additionally update its sending keys. It is an
1247+
error if updatetype is set to KEY_UPDATE_NONE.
1248+
1249+
key_update() must only be called after the initial handshake has been
1250+
completed and TLSv1.3 has been negotiated. The key update will not take
1251+
place until the next time an IO operation such as read() or write()
1252+
takes place on the connection. Alternatively do_handshake() can be
1253+
called to force the update to take place immediately.
1254+
1255+
Raises NotImplementedError if the TLS implementation doesn't support
1256+
TLS 1.3.)
1257+
1258+
:param updatetype: KeyUpdateTypes
1259+
"""),
1260+
('key_update_type', """\
1261+
Determine whether a key update operation has been scheduled but
1262+
not yet performed.
1263+
1264+
The type of the pending key update operation will be returned if there
1265+
is one, or KEY_UPDATE_NONE otherwise.
1266+
1267+
Raises NotImplementedError if the TLS implementation doesn't support
1268+
TLS 1.3.
1269+
"""),
1270+
('renegotiate', """\
1271+
Start the SSL/TLS renegotiation, requires TLS <= 1.2."""),
1272+
('renegotiate_pending', """\
1273+
Return True if a renegotiation or renegotiation request has been
1274+
scheduled but not yet acted on, or False otherwise."""),
1275+
):
1276+
for cls in (SSLObject, SSLSocket):
1277+
getattr(cls, name).__doc__ = docstr
1278+
1279+
del name, docstr, cls
1280+
11931281

11941282
# Python does not support forward declaration of types.
11951283
SSLContext.sslsocket_class = SSLSocket

Lib/test/test_ssl.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,39 @@ def test_bad_server_hostname(self):
16451645
ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
16461646
server_hostname="example.org\x00evil.com")
16471647

1648+
@unittest.skipUnless(ssl.HAS_TLSv1_3,
1649+
"test requires TLSv1.3 enabled OpenSSL")
1650+
def test_invalid_key_update_type_and_still_in_init(self):
1651+
b1 = ssl.MemoryBIO()
1652+
b2 = ssl.MemoryBIO()
1653+
ctx = ssl.SSLContext()
1654+
ctx.load_cert_chain(SIGNED_CERTFILE)
1655+
server = ctx.wrap_bio(b1, b2, server_side=True)
1656+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1657+
ctx.options |= (
1658+
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
1659+
)
1660+
sslobj = ctx.wrap_bio(b2, b1, server_side=False)
1661+
1662+
handshaking = True
1663+
while handshaking:
1664+
try:
1665+
sslobj.do_handshake()
1666+
handshaking = False
1667+
except ssl.SSLWantReadError:
1668+
handshaking = True
1669+
1670+
try:
1671+
server.do_handshake()
1672+
except ssl.SSLWantReadError:
1673+
handshaking = True
1674+
1675+
with self.assertRaisesRegex(ssl.SSLError, 'invalid key update type'):
1676+
sslobj.key_update(ssl.KEY_UPDATE_NONE)
1677+
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
1678+
with self.assertRaisesRegex(ssl.SSLError, 'still in init'):
1679+
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
1680+
16481681

16491682
class MemoryBIOTests(unittest.TestCase):
16501683

@@ -2076,6 +2109,91 @@ def test_bio_read_write_data(self):
20762109
self.assertEqual(buf, b'foo\n')
20772110
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
20782111

2112+
def test_bio_renegotiation(self):
2113+
sock = socket.socket(socket.AF_INET)
2114+
self.addCleanup(sock.close)
2115+
sock.connect(self.server_addr)
2116+
incoming = ssl.MemoryBIO()
2117+
outgoing = ssl.MemoryBIO()
2118+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2119+
ctx.verify_mode = ssl.CERT_NONE
2120+
ctx.options |= ssl.OP_NO_TLSv1_3
2121+
sslobj = ctx.wrap_bio(incoming, outgoing, False)
2122+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2123+
2124+
self.assertEqual(outgoing.pending, 0)
2125+
sslobj.renegotiate()
2126+
self.assertEqual(outgoing.pending, 0)
2127+
self.assertTrue(sslobj.renegotiate_pending)
2128+
req = b'FOO\n'
2129+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2130+
self.assertFalse(sslobj.renegotiate_pending)
2131+
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2132+
self.assertEqual(buf, b'foo\n')
2133+
2134+
self.assertEqual(outgoing.pending, 0)
2135+
sslobj.renegotiate(abbreviated=True)
2136+
self.assertEqual(outgoing.pending, 0)
2137+
self.assertTrue(sslobj.renegotiate_pending)
2138+
req = b'BAR\n'
2139+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2140+
self.assertFalse(sslobj.renegotiate_pending)
2141+
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2142+
self.assertEqual(buf, b'bar\n')
2143+
2144+
if IS_OPENSSL_1_1_1 and ssl.HAS_TLSv1_3:
2145+
with self.assertRaises(ssl.SSLError,
2146+
msg='wrong ssl version'):
2147+
sslobj.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
2148+
with self.assertRaises(ssl.SSLError,
2149+
msg='wrong ssl version'):
2150+
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
2151+
2152+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2153+
2154+
@unittest.skipUnless(ssl.HAS_TLSv1_3,
2155+
"test requires TLSv1.3 enabled OpenSSL")
2156+
def test_bio_key_update(self):
2157+
sock = socket.socket(socket.AF_INET)
2158+
self.addCleanup(sock.close)
2159+
sock.connect(self.server_addr)
2160+
incoming = ssl.MemoryBIO()
2161+
outgoing = ssl.MemoryBIO()
2162+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2163+
ctx.verify_mode = ssl.CERT_NONE
2164+
ctx.options |= (
2165+
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
2166+
)
2167+
sslobj = ctx.wrap_bio(incoming, outgoing, False)
2168+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2169+
2170+
self.assertEqual(outgoing.pending, 0)
2171+
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
2172+
self.assertEqual(outgoing.pending, 0)
2173+
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_REQUESTED)
2174+
req = b'FOO\n'
2175+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2176+
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NONE)
2177+
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2178+
self.assertEqual(buf, b'foo\n')
2179+
2180+
self.assertEqual(outgoing.pending, 0)
2181+
sslobj.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
2182+
self.assertEqual(outgoing.pending, 0)
2183+
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NOT_REQUESTED)
2184+
req = b'BAR\n'
2185+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2186+
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NONE)
2187+
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2188+
self.assertEqual(buf, b'bar\n')
2189+
2190+
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
2191+
sslobj.renegotiate()
2192+
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
2193+
sslobj.renegotiate(abbreviated=True)
2194+
2195+
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2196+
20792197

20802198
class NetworkedTests(unittest.TestCase):
20812199

@@ -4164,6 +4282,78 @@ def test_session_handling(self):
41644282
self.assertEqual(str(e.exception),
41654283
'Session refers to a different SSLContext.')
41664284

4285+
def test_renegotiation(self):
4286+
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4287+
context.load_cert_chain(CERTFILE)
4288+
context.options |= ssl.OP_NO_TLSv1_3
4289+
with ThreadedEchoServer(context=context) as server:
4290+
with context.wrap_socket(socket.socket()) as s:
4291+
s.connect((HOST, server.port))
4292+
self.assertFalse(s.renegotiate_pending)
4293+
s.renegotiate()
4294+
self.assertTrue(s.renegotiate_pending)
4295+
s.send(b'HELLO')
4296+
self.assertEqual(s.recv(1024), b'hello')
4297+
self.assertFalse(s.renegotiate_pending)
4298+
s.send(b'WORLD')
4299+
self.assertEqual(s.recv(1024), b'world')
4300+
4301+
self.assertFalse(s.renegotiate_pending)
4302+
s.renegotiate(abbreviated=True)
4303+
self.assertTrue(s.renegotiate_pending)
4304+
s.send(b'HELLO')
4305+
self.assertEqual(s.recv(1024), b'hello')
4306+
self.assertFalse(s.renegotiate_pending)
4307+
s.send(b'WORLD')
4308+
self.assertEqual(s.recv(1024), b'world')
4309+
4310+
if IS_OPENSSL_1_1_1 and ssl.HAS_TLSv1_3:
4311+
with self.assertRaises(ssl.SSLError,
4312+
msg='wrong ssl version'):
4313+
s.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
4314+
with self.assertRaises(ssl.SSLError,
4315+
msg='wrong ssl version'):
4316+
s.key_update(ssl.KEY_UPDATE_REQUESTED)
4317+
4318+
@unittest.skipUnless(ssl.HAS_TLSv1_3,
4319+
"test requires TLSv1.3 enabled OpenSSL")
4320+
def test_key_update(self):
4321+
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4322+
context.load_cert_chain(CERTFILE)
4323+
context.options |= (
4324+
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
4325+
)
4326+
with ThreadedEchoServer(context=context) as server:
4327+
with context.wrap_socket(socket.socket()) as s:
4328+
s.connect((HOST, server.port))
4329+
self.assertEqual(s.version(), 'TLSv1.3')
4330+
4331+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4332+
s.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
4333+
self.assertEqual(s.key_update_type,
4334+
ssl.KEY_UPDATE_NOT_REQUESTED)
4335+
s.send(b'HELLO')
4336+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4337+
self.assertEqual(s.recv(1024), b'hello')
4338+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4339+
s.send(b'WORLD')
4340+
self.assertEqual(s.recv(1024), b'world')
4341+
4342+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4343+
s.key_update(ssl.KEY_UPDATE_REQUESTED)
4344+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_REQUESTED)
4345+
s.send(b'HELLO')
4346+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4347+
self.assertEqual(s.recv(1024), b'hello')
4348+
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
4349+
s.send(b'WORLD')
4350+
self.assertEqual(s.recv(1024), b'world')
4351+
4352+
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
4353+
s.renegotiate()
4354+
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
4355+
s.renegotiate(abbreviated=True)
4356+
41674357

41684358
@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
41694359
class TestPostHandshakeAuth(unittest.TestCase):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added ``renegotiate()`` and ``key_update()`` in :mod:`ssl`.

0 commit comments

Comments
 (0)