@@ -355,13 +355,20 @@ def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):
355355 self = _SSLContext .__new__ (cls , protocol )
356356 return self
357357
358- def __init__ (self , protocol = PROTOCOL_TLS ):
359- self .protocol = protocol
358+ def _encode_hostname (self , hostname ):
359+ if hostname is None :
360+ return None
361+ elif isinstance (hostname , str ):
362+ return hostname .encode ('idna' ).decode ('ascii' )
363+ else :
364+ return hostname .decode ('ascii' )
360365
361366 def wrap_socket (self , sock , server_side = False ,
362367 do_handshake_on_connect = True ,
363368 suppress_ragged_eofs = True ,
364369 server_hostname = None , session = None ):
370+ # SSLSocket class handles server_hostname encoding before it calls
371+ # ctx._wrap_socket()
365372 return self .sslsocket_class (
366373 sock = sock ,
367374 server_side = server_side ,
@@ -374,8 +381,12 @@ def wrap_socket(self, sock, server_side=False,
374381
375382 def wrap_bio (self , incoming , outgoing , server_side = False ,
376383 server_hostname = None , session = None ):
377- sslobj = self ._wrap_bio (incoming , outgoing , server_side = server_side ,
378- server_hostname = server_hostname )
384+ # Need to encode server_hostname here because _wrap_bio() can only
385+ # handle ASCII str.
386+ sslobj = self ._wrap_bio (
387+ incoming , outgoing , server_side = server_side ,
388+ server_hostname = self ._encode_hostname (server_hostname )
389+ )
379390 return self .sslobject_class (sslobj , session = session )
380391
381392 def set_npn_protocols (self , npn_protocols ):
@@ -389,6 +400,20 @@ def set_npn_protocols(self, npn_protocols):
389400
390401 self ._set_npn_protocols (protos )
391402
403+ def set_servername_callback (self , server_name_callback ):
404+ if server_name_callback is None :
405+ self .sni_callback = None
406+ else :
407+ if not hasattr (server_name_callback , '__call__' ):
408+ raise TypeError ("not a callable object" )
409+
410+ def shim_cb (sslobj , servername , sslctx ):
411+ if servername is not None :
412+ servername = servername .encode ("ascii" ).decode ("idna" )
413+ return server_name_callback (sslobj , servername , sslctx )
414+
415+ self .sni_callback = shim_cb
416+
392417 def set_alpn_protocols (self , alpn_protocols ):
393418 protos = bytearray ()
394419 for protocol in alpn_protocols :
@@ -447,6 +472,10 @@ def hostname_checks_common_name(self, value):
447472 def hostname_checks_common_name (self ):
448473 return True
449474
475+ @property
476+ def protocol (self ):
477+ return _SSLMethod (super ().protocol )
478+
450479 @property
451480 def verify_flags (self ):
452481 return VerifyFlags (super ().verify_flags )
@@ -749,7 +778,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
749778 raise ValueError ("check_hostname requires server_hostname" )
750779 self ._session = _session
751780 self .server_side = server_side
752- self .server_hostname = server_hostname
781+ self .server_hostname = self . _context . _encode_hostname ( server_hostname )
753782 self .do_handshake_on_connect = do_handshake_on_connect
754783 self .suppress_ragged_eofs = suppress_ragged_eofs
755784 if sock is not None :
@@ -781,7 +810,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
781810 # create the SSL object
782811 try :
783812 sslobj = self ._context ._wrap_socket (self , server_side ,
784- server_hostname )
813+ self . server_hostname )
785814 self ._sslobj = SSLObject (sslobj , owner = self ,
786815 session = self ._session )
787816 if do_handshake_on_connect :
0 commit comments