You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					222 lines
				
				6.7 KiB
			
		
		
			
		
	
	
					222 lines
				
				6.7 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import io
							 | 
						||
| 
								 | 
							
								import socket
							 | 
						||
| 
								 | 
							
								import ssl
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from ..exceptions import ProxySchemeUnsupported
							 | 
						||
| 
								 | 
							
								from ..packages import six
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Maximum number of bytes requested from the wrapped socket per recv() call
								# while SSLTransport._ssl_io_loop is waiting for more TLS data.
								SSL_BLOCKSIZE = 16384
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SSLTransport:
								    """
								    The SSLTransport wraps an existing socket and establishes an SSL connection.

								    Contrary to Python's implementation of SSLSocket, it allows you to chain
								    multiple TLS connections together. It's particularly useful if you need to
								    implement TLS within TLS.

								    The class supports most of the socket API operations.
								    """

								    @staticmethod
								    def _validate_ssl_context_for_tls_in_tls(ssl_context):
								        """
								        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
								        for TLS in TLS.

								        The only requirement is that the ssl_context provides the 'wrap_bio'
								        method.

								        :param ssl_context: The context object to check.
								        :raises ProxySchemeUnsupported: If ``ssl_context`` has no ``wrap_bio``
								            attribute; the message differs between Python 2 and Python 3.
								        """
								        if hasattr(ssl_context, "wrap_bio"):
								            return

								        if six.PY2:
								            raise ProxySchemeUnsupported(
								                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
								                "supported on Python 2"
								            )
								        raise ProxySchemeUnsupported(
								            "TLS in TLS requires SSLContext.wrap_bio() which isn't "
								            "available on non-native SSLContext"
								        )

								    def __init__(
								        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
								    ):
								        """
								        Create an SSLTransport around socket using the provided ssl_context.

								        :param socket: Socket-like object carrying the encrypted bytes. It is
								            used through ``sendall()``/``recv()`` and a few pass-through
								            methods (``close()``, ``fileno()``, ``settimeout()``, ...).
								        :param ssl_context: Context object providing ``wrap_bio()``.
								        :param server_hostname: Forwarded unchanged to ``wrap_bio()``.
								        :param suppress_ragged_eofs: When true, an ``SSL_ERROR_EOF`` raised
								            while reading is reported as end-of-stream instead of being
								            re-raised.

								        The handshake runs before the constructor returns, so any
								        ``ssl.SSLError`` it produces propagates to the caller.
								        """
								        # TLS records travel through these in-memory BIOs rather than
								        # straight to the socket; _ssl_io_loop() shuttles bytes between the
								        # BIOs and the wrapped socket.
								        self.incoming = ssl.MemoryBIO()
								        self.outgoing = ssl.MemoryBIO()

								        self.suppress_ragged_eofs = suppress_ragged_eofs
								        self.socket = socket

								        self.sslobj = ssl_context.wrap_bio(
								            self.incoming, self.outgoing, server_hostname=server_hostname
								        )

								        # Perform initial handshake.
								        self._ssl_io_loop(self.sslobj.do_handshake)

								    def __enter__(self):
								        """Return the transport itself so it can be used in a ``with`` block."""
								        return self

								    def __exit__(self, *_):
								        """Close the transport when leaving the ``with`` block."""
								        self.close()

								    def fileno(self):
								        """Return the file descriptor reported by the wrapped socket."""
								        return self.socket.fileno()

								    def read(self, len=1024, buffer=None):
								        """
								        Read up to ``len`` decrypted bytes.

								        The arguments are passed to ``self.sslobj.read``. On a suppressed
								        ragged EOF the result is ``0`` when ``buffer`` was given and ``b""``
								        otherwise.
								        """
								        return self._wrap_ssl_read(len, buffer)

								    def recv(self, len=1024, flags=0):
								        """
								        Receive up to ``len`` decrypted bytes.

								        :raises ValueError: If ``flags`` is non-zero.
								        """
								        if flags != 0:
								            raise ValueError("non-zero flags not allowed in calls to recv")
								        return self._wrap_ssl_read(len)

								    def recv_into(self, buffer, nbytes=None, flags=0):
								        """
								        Receive decrypted bytes into ``buffer``.

								        When ``nbytes`` is omitted it defaults to ``len(buffer)`` for a
								        non-empty buffer and to 1024 otherwise.

								        :raises ValueError: If ``flags`` is non-zero.
								        """
								        if flags != 0:
								            raise ValueError("non-zero flags not allowed in calls to recv_into")
								        if buffer and (nbytes is None):
								            nbytes = len(buffer)
								        elif nbytes is None:
								            nbytes = 1024
								        return self.read(nbytes, buffer)

								    def sendall(self, data, flags=0):
								        """
								        Send all of ``data``, calling :meth:`send` until nothing is left.

								        :raises ValueError: If ``flags`` is non-zero.
								        """
								        if flags != 0:
								            raise ValueError("non-zero flags not allowed in calls to sendall")
								        count = 0
								        # Cast to a flat byte view so that len() and slicing count bytes
								        # regardless of the item size of the object passed in.
								        with memoryview(data) as view, view.cast("B") as byte_view:
								            amount = len(byte_view)
								            while count < amount:
								                v = self.send(byte_view[count:])
								                count += v

								    def send(self, data, flags=0):
								        """
								        Write ``data`` through the SSL object and return its result.

								        :raises ValueError: If ``flags`` is non-zero.
								        """
								        if flags != 0:
								            raise ValueError("non-zero flags not allowed in calls to send")
								        response = self._ssl_io_loop(self.sslobj.write, data)
								        return response

								    def makefile(
								        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
								    ):
								        """
								        Python's httpclient uses makefile and buffered io when reading HTTP
								        messages and we need to support it.

								        This is unfortunately a copy and paste of socket.py makefile with small
								        changes to point to the socket directly.

								        :param mode: Any combination of ``"r"``, ``"w"`` and ``"b"``.
								        :param buffering: ``None`` or a negative value selects
								            ``io.DEFAULT_BUFFER_SIZE``; ``0`` returns the raw, unbuffered
								            object and is only allowed in binary mode.
								        :param encoding: Passed to ``io.TextIOWrapper`` in text mode.
								        :param errors: Passed to ``io.TextIOWrapper`` in text mode.
								        :param newline: Passed to ``io.TextIOWrapper`` in text mode.
								        :raises ValueError: For an invalid ``mode`` or for an unbuffered
								            text-mode request.
								        """
								        if not set(mode) <= {"r", "w", "b"}:
								            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

								        writing = "w" in mode
								        reading = "r" in mode or not writing
								        assert reading or writing
								        binary = "b" in mode
								        rawmode = ""
								        if reading:
								            rawmode += "r"
								        if writing:
								            rawmode += "w"
								        # The raw IO object reads and writes through this transport, not
								        # through the wrapped socket.
								        raw = socket.SocketIO(self, rawmode)
								        # Count the new file object on the wrapped socket;
								        # _decref_socketios() is the counterpart.
								        self.socket._io_refs += 1
								        if buffering is None:
								            buffering = -1
								        if buffering < 0:
								            buffering = io.DEFAULT_BUFFER_SIZE
								        if buffering == 0:
								            if not binary:
								                raise ValueError("unbuffered streams must be binary")
								            return raw
								        if reading and writing:
								            buffer = io.BufferedRWPair(raw, raw, buffering)
								        elif reading:
								            buffer = io.BufferedReader(raw, buffering)
								        else:
								            assert writing
								            buffer = io.BufferedWriter(raw, buffering)
								        if binary:
								            return buffer
								        text = io.TextIOWrapper(buffer, encoding, errors, newline)
								        text.mode = mode
								        return text

								    def unwrap(self):
								        """Run ``self.sslobj.unwrap`` through the I/O loop."""
								        self._ssl_io_loop(self.sslobj.unwrap)

								    def close(self):
								        """Close the wrapped socket."""
								        self.socket.close()

								    def getpeercert(self, binary_form=False):
								        """Return ``self.sslobj.getpeercert(binary_form)``."""
								        return self.sslobj.getpeercert(binary_form)

								    def version(self):
								        """Return ``self.sslobj.version()``."""
								        return self.sslobj.version()

								    def cipher(self):
								        """Return ``self.sslobj.cipher()``."""
								        return self.sslobj.cipher()

								    def selected_alpn_protocol(self):
								        """Return ``self.sslobj.selected_alpn_protocol()``."""
								        return self.sslobj.selected_alpn_protocol()

								    def selected_npn_protocol(self):
								        """Return ``self.sslobj.selected_npn_protocol()``."""
								        return self.sslobj.selected_npn_protocol()

								    def shared_ciphers(self):
								        """Return ``self.sslobj.shared_ciphers()``."""
								        return self.sslobj.shared_ciphers()

								    def compression(self):
								        """Return ``self.sslobj.compression()``."""
								        return self.sslobj.compression()

								    def settimeout(self, value):
								        """Set the timeout on the wrapped socket."""
								        self.socket.settimeout(value)

								    def gettimeout(self):
								        """Return the timeout of the wrapped socket."""
								        return self.socket.gettimeout()

								    def _decref_socketios(self):
								        """Forward the file-object reference decrement to the wrapped socket."""
								        self.socket._decref_socketios()

								    def _wrap_ssl_read(self, len, buffer=None):
								        """
								        Read through the I/O loop, optionally hiding a ragged EOF.

								        :param len: Maximum number of bytes to read.
								        :param buffer: Optional writable buffer to fill.
								        :returns: Whatever ``self.sslobj.read`` returns. On an
								            ``SSL_ERROR_EOF`` with ``suppress_ragged_eofs`` set, ``0`` when a
								            buffer was supplied and ``b""`` otherwise.
								        :raises ssl.SSLError: For every other SSL error, or for
								            ``SSL_ERROR_EOF`` when ``suppress_ragged_eofs`` is false.
								        """
								        try:
								            return self._ssl_io_loop(self.sslobj.read, len, buffer)
								        except ssl.SSLError as e:
								            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
								                # Keep the return type consistent with a successful read:
								                # a byte count when filling a buffer, bytes otherwise
								                # (the same convention ssl.SSLSocket.read() follows).
								                return 0 if buffer is not None else b""
								            raise

								    def _ssl_io_loop(self, func, *args):
								        """
								        Performs an I/O loop between incoming/outgoing and the socket.

								        ``func(*args)`` is retried until it stops raising
								        ``SSL_ERROR_WANT_READ``/``SSL_ERROR_WANT_WRITE``. After every attempt
								        the pending outgoing bytes are sent on the wrapped socket; on
								        WANT_READ, more bytes are received and fed to the incoming BIO.

								        :returns: The value returned by the successful ``func`` call.
								        :raises ssl.SSLError: Any SSL error other than WANT_READ/WANT_WRITE.
								        """
								        should_loop = True
								        ret = None

								        while should_loop:
								            errno = None
								            try:
								                ret = func(*args)
								            except ssl.SSLError as e:
								                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
								                    # WANT_READ, and WANT_WRITE are expected, others are not.
								                    raise
								                errno = e.errno

								            # Send whatever the SSL object produced, whether or not the
								            # call succeeded.
								            buf = self.outgoing.read()
								            self.socket.sendall(buf)

								            if errno is None:
								                should_loop = False
								            elif errno == ssl.SSL_ERROR_WANT_READ:
								                buf = self.socket.recv(SSL_BLOCKSIZE)
								                if buf:
								                    self.incoming.write(buf)
								                else:
								                    # An empty recv() means the peer closed the connection;
								                    # mark EOF on the BIO so the SSL object sees it.
								                    self.incoming.write_eof()
								        return ret
							 |