You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					282 lines
				
				12 KiB
			
		
		
			
		
	
	
					282 lines
				
				12 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import logging
							 | 
						||
| 
								 | 
							
								import re
							 | 
						||
| 
								 | 
							
								import ssl
							 | 
						||
| 
								 | 
							
								from dataclasses import dataclass
							 | 
						||
| 
								 | 
							
								from functools import wraps
							 | 
						||
| 
								 | 
							
								from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .. import BrokenResourceError, EndOfStream, aclose_forcefully, get_cancelled_exc_class
							 | 
						||
| 
								 | 
							
								from .._core._typedattr import TypedAttributeSet, typed_attribute
							 | 
						||
| 
								 | 
							
								from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Generic return type: lets ``TLSStream._call_sslobject_method`` return whatever type the
								# callable passed to it returns.
								T_Retval = TypeVar('T_Retval')
							 | 
						||
| 
								 | 
							
								# A variable-length tuple of ``(str, str)`` pairs; part of the value type used in the
								# ``TLSAttribute.peer_certificate`` annotation.
								_PCTRTT = Tuple[Tuple[str, str], ...]
							 | 
						||
| 
								 | 
							
								# A variable-length tuple of ``_PCTRTT`` tuples; also part of the value type used in the
								# ``TLSAttribute.peer_certificate`` annotation.
								_PCTRTTT = Tuple[_PCTRTT, ...]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TLSAttribute(TypedAttributeSet):
								    """Typed extra attributes that describe a Transport Layer Security session."""

								    #: ALPN protocol that was selected for the session
								    alpn_protocol: Optional[str] = typed_attribute()

								    #: channel binding data of type ``tls-unique``
								    channel_binding_tls_unique: bytes = typed_attribute()

								    #: cipher that was selected for the session
								    cipher: Tuple[str, str, int] = typed_attribute()

								    #: peer certificate as a dictionary (:meth:`ssl.SSLSocket.getpeercert` documents the
								    #: format)
								    peer_certificate: Optional[Dict[str, Union[str, _PCTRTTT, _PCTRTT]]] = typed_attribute()

								    #: peer certificate as raw bytes
								    peer_certificate_binary: Optional[bytes] = typed_attribute()

								    #: ``True`` when this end is the server side of the connection
								    server_side: bool = typed_attribute()

								    #: ciphers that both ends of the TLS connection have in common
								    shared_ciphers: List[Tuple[str, str, int]] = typed_attribute()

								    #: the :class:`~ssl.SSLObject` performing the encryption
								    ssl_object: ssl.SSLObject = typed_attribute()

								    #: ``True`` if the stream performs (and expects from the peer) the closing TLS handshake
								    #: when it is being closed
								    standard_compatible: bool = typed_attribute()

								    #: version of the TLS protocol in use (e.g. ``TLSv1.2``)
								    tls_version: str = typed_attribute()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@dataclass(eq=False)
								class TLSStream(ByteStream):
								    """
								    A stream wrapper that encrypts all sent data and decrypts received data.

								    This class has no public initializer; use :meth:`wrap` instead.
								    All extra attributes from :class:`~TLSAttribute` are supported.

								    :var AnyByteStream transport_stream: the wrapped stream

								    """

								    # The underlying stream that carries the encrypted bytes
								    transport_stream: AnyByteStream
								    # Whether the closing TLS handshake is performed and expected (see ``aclose()``)
								    standard_compatible: bool
								    # SSL state machine operating on the two memory BIOs below
								    _ssl_object: ssl.SSLObject
								    # Encrypted bytes received from the transport are written here for the SSL object
								    _read_bio: ssl.MemoryBIO
								    # Encrypted bytes produced by the SSL object are read from here and sent out
								    _write_bio: ssl.MemoryBIO

								    @classmethod
								    async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bool] = None,
								                   hostname: Optional[str] = None, ssl_context: Optional[ssl.SSLContext] = None,
								                   standard_compatible: bool = True) -> 'TLSStream':
								        """
								        Wrap an existing stream with Transport Layer Security.

								        This performs a TLS handshake with the peer.

								        :param transport_stream: a bytes-transporting stream to wrap
								        :param server_side: ``True`` if this is the server side of the connection, ``False`` if
								            this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
								            provided, ``True`` otherwise). Passed to :meth:`ssl.SSLContext.wrap_bio` and also
								            used to pick the purpose of the default context when an explicit context has not
								            been provided.
								        :param hostname: host name of the peer (if host name checking is desired)
								        :param ssl_context: the SSLContext object to use (if not provided, a secure default will be
								            created)
								        :param standard_compatible: if ``False``, skip the closing handshake when closing the
								            connection, and don't raise an exception if the peer does the same
								        :return: the wrapper stream, with the handshake already completed
								        :raises ~ssl.SSLError: if the TLS handshake fails

								        """
								        if server_side is None:
								            server_side = not hostname

								        if ssl_context is None:
								            purpose = ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
								            ssl_context = ssl.create_default_context(purpose)

								            # Re-enable detection of unexpected EOFs if it was disabled by Python.
								            # The bit must be cleared with "&= ~flag": toggling it with XOR would switch the
								            # flag ON whenever the default context did not already have it set.
								            if hasattr(ssl, 'OP_IGNORE_UNEXPECTED_EOF'):
								                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]

								        bio_in = ssl.MemoryBIO()
								        bio_out = ssl.MemoryBIO()
								        ssl_object = ssl_context.wrap_bio(bio_in, bio_out, server_side=server_side,
								                                          server_hostname=hostname)
								        wrapper = cls(transport_stream=transport_stream,
								                      standard_compatible=standard_compatible, _ssl_object=ssl_object,
								                      _read_bio=bio_in, _write_bio=bio_out)
								        await wrapper._call_sslobject_method(ssl_object.do_handshake)
								        return wrapper

								    async def _call_sslobject_method(
								        self, func: Callable[..., T_Retval], *args: object
								    ) -> T_Retval:
								        """
								        Call ``func(*args)`` until it completes, moving data between the memory BIOs and the
								        transport stream whenever the SSL object asks for more input or has pending output.

								        :param func: the callable to invoke (a method of the SSL object)
								        :param args: positional arguments passed to ``func``
								        :return: the return value of ``func``
								        :raises BrokenResourceError: if the transport stream raises :exc:`OSError`, if the SSL
								            object raises :exc:`ssl.SSLSyscallError`, or on an unexpected EOF when
								            ``standard_compatible`` is ``True``
								        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is ``False``
								        :raises ~ssl.SSLError: for any other TLS error

								        """
								        while True:
								            try:
								                result = func(*args)
								            except ssl.SSLWantReadError:
								                try:
								                    # Flush any pending writes first
								                    if self._write_bio.pending:
								                        await self.transport_stream.send(self._write_bio.read())

								                    data = await self.transport_stream.receive()
								                except EndOfStream:
								                    # Signal EOF to the SSL object; the retry decides how to report it
								                    self._read_bio.write_eof()
								                except OSError as exc:
								                    self._read_bio.write_eof()
								                    self._write_bio.write_eof()
								                    raise BrokenResourceError from exc
								                else:
								                    self._read_bio.write(data)
								            except ssl.SSLWantWriteError:
								                await self.transport_stream.send(self._write_bio.read())
								            except ssl.SSLSyscallError as exc:
								                self._read_bio.write_eof()
								                self._write_bio.write_eof()
								                raise BrokenResourceError from exc
								            except ssl.SSLError as exc:
								                self._read_bio.write_eof()
								                self._write_bio.write_eof()
								                # ``strerror`` may be None, so guard it before the substring test; otherwise
								                # a TypeError would mask the original SSL error.
								                unexpected_eof = isinstance(exc, ssl.SSLEOFError) or (
								                    bool(exc.strerror) and 'UNEXPECTED_EOF_WHILE_READING' in exc.strerror)
								                if unexpected_eof:
								                    if self.standard_compatible:
								                        raise BrokenResourceError from exc
								                    else:
								                        raise EndOfStream from None

								                raise
								            else:
								                # Flush any pending writes first
								                if self._write_bio.pending:
								                    await self.transport_stream.send(self._write_bio.read())

								                return result

								    async def unwrap(self) -> Tuple[AnyByteStream, bytes]:
								        """
								        Does the TLS closing handshake.

								        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

								        """
								        await self._call_sslobject_method(self._ssl_object.unwrap)
								        self._read_bio.write_eof()
								        self._write_bio.write_eof()
								        return self.transport_stream, self._read_bio.read()

								    async def aclose(self) -> None:
								        """
								        Close the stream.

								        If ``standard_compatible`` is ``True``, the closing handshake is attempted first via
								        :meth:`unwrap`; if that fails for any reason, the transport stream is closed
								        forcefully and the exception is re-raised. Otherwise the transport stream is simply
								        closed.

								        """
								        if self.standard_compatible:
								            try:
								                await self.unwrap()
								            except BaseException:
								                await aclose_forcefully(self.transport_stream)
								                raise

								        await self.transport_stream.aclose()

								    async def receive(self, max_bytes: int = 65536) -> bytes:
								        """
								        Receive and decrypt data from the peer.

								        :param max_bytes: maximum number of bytes to read from the SSL object
								        :return: the decrypted bytes
								        :raises EndOfStream: if the SSL object returned no data

								        """
								        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
								        if not data:
								            raise EndOfStream

								        return data

								    async def send(self, item: bytes) -> None:
								        """
								        Encrypt the given bytes and send them to the peer.

								        :param item: the bytes to send

								        """
								        await self._call_sslobject_method(self._ssl_object.write, item)

								    async def send_eof(self) -> None:
								        """
								        Not implemented for TLS streams.

								        :raises NotImplementedError: always; the message names the negotiated TLS version when
								            it is older than TLSv1.3

								        """
								        tls_version = self.extra(TLSAttribute.tls_version)
								        match = re.match(r'TLSv(\d+)(?:\.(\d+))?', tls_version)
								        if match:
								            major, minor = int(match.group(1)), int(match.group(2) or 0)
								            if (major, minor) < (1, 3):
								                raise NotImplementedError(f'send_eof() requires at least TLSv1.3; current '
								                                          f'session uses {tls_version}')

								        raise NotImplementedError('send_eof() has not yet been implemented for TLS streams')

								    @property
								    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
								        """
								        Extra attributes of the wrapped stream, plus the :class:`~TLSAttribute` attributes
								        provided by this stream.

								        """
								        return {
								            **self.transport_stream.extra_attributes,
								            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
								            TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
								            TLSAttribute.cipher: self._ssl_object.cipher,
								            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
								            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(True),
								            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
								            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
								            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
								            TLSAttribute.ssl_object: lambda: self._ssl_object,
								            TLSAttribute.tls_version: self._ssl_object.version
								        }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@dataclass(eq=False)
								class TLSListener(Listener[TLSStream]):
								    """
								    A convenience listener that wraps another listener and auto-negotiates a TLS session on every
								    accepted connection.

								    If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
								    called to do whatever post-mortem processing is deemed necessary.

								    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

								    :param Listener listener: the listener to wrap
								    :param ssl_context: the SSL context object
								    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
								    :param handshake_timeout: time limit for the TLS handshake
								        (passed to :func:`~anyio.fail_after`)
								    """

								    listener: Listener[Any]
								    ssl_context: ssl.SSLContext
								    standard_compatible: bool = True
								    handshake_timeout: float = 30

								    @staticmethod
								    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
								        """
								        Handle an exception raised during the TLS handshake.

								        This method does 3 things:

								        #. Forcefully closes the original stream
								        #. Logs the exception (unless it was a cancellation exception) using this module's
								           logger (``logging.getLogger(__name__)``)
								        #. Reraises the exception if it was a base exception or a cancellation exception

								        :param exc: the exception
								        :param stream: the original stream

								        """
								        # NOTE: this must be a plain string literal; an f-string is not stored as __doc__.
								        await aclose_forcefully(stream)

								        cancelled = isinstance(exc, get_cancelled_exc_class())

								        # Log all except cancellation exceptions. The exception is passed explicitly so the
								        # log record does not depend on the ambient exception state after the await above.
								        if not cancelled:
								            logging.getLogger(__name__).exception('Error during TLS handshake', exc_info=exc)

								        # Only reraise base exceptions and cancellation exceptions. Raising ``exc`` itself
								        # (rather than a bare ``raise``) also works when no ``except`` block is active.
								        if cancelled or not isinstance(exc, Exception):
								            raise exc

								    async def serve(self, handler: Callable[[TLSStream], Any],
								                    task_group: Optional[TaskGroup] = None) -> None:
								        """
								        Serve connections from the wrapped listener, performing the TLS handshake on each
								        accepted stream before passing the resulting :class:`TLSStream` to ``handler``.

								        :param handler: a callable awaited with the TLS-wrapped stream of each connection
								        :param task_group: passed through to the wrapped listener's ``serve()``

								        """
								        @wraps(handler)
								        async def handler_wrapper(stream: AnyByteStream) -> None:
								            from .. import fail_after
								            try:
								                # The handshake must finish within ``handshake_timeout``
								                with fail_after(self.handshake_timeout):
								                    wrapped_stream = await TLSStream.wrap(
								                        stream, ssl_context=self.ssl_context,
								                        standard_compatible=self.standard_compatible)
								            except BaseException as exc:
								                await self.handle_handshake_error(exc, stream)
								            else:
								                await handler(wrapped_stream)

								        await self.listener.serve(handler_wrapper, task_group)

								    async def aclose(self) -> None:
								        """Close the wrapped listener."""
								        await self.listener.aclose()

								    @property
								    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
								        """Expose only the :attr:`~TLSAttribute.standard_compatible` extra attribute."""
								        return {
								            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
								        }
							 |