You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					157 lines
				
				5.5 KiB
			
		
		
			
		
	
	
					157 lines
				
				5.5 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import socket
							 | 
						||
| 
								 | 
							
								from abc import abstractmethod
							 | 
						||
| 
								 | 
							
								from io import IOBase
							 | 
						||
| 
								 | 
							
								from ipaddress import IPv4Address, IPv6Address
							 | 
						||
| 
								 | 
							
								from socket import AddressFamily
							 | 
						||
| 
								 | 
							
								from types import TracebackType
							 | 
						||
| 
								 | 
							
								from typing import (
							 | 
						||
| 
								 | 
							
								    Any, AsyncContextManager, Callable, Collection, Dict, List, Mapping, Optional, Tuple, Type,
							 | 
						||
| 
								 | 
							
								    TypeVar, Union)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
							 | 
						||
| 
								 | 
							
								from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream
							 | 
						||
| 
								 | 
							
								from ._tasks import TaskGroup
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# A host given either as text or as an ipaddress object.
								IPAddressType = Union[str, IPv4Address, IPv6Address]

								# An IP socket address as a (host, port) pair.
								IPSockAddrType = Tuple[str, int]

								# Any socket address: an IP (host, port) pair or a plain string
								# (the string form presumably covers UNIX socket paths).
								SockAddrType = Union[IPSockAddrType, str]

								# One datagram as handled by UDPSocket: (payload, (host, port)).
								UDPPacketType = Tuple[bytes, IPSockAddrType]

								# Type variable for generic return values; not referenced in the rest of this module.
								T_Retval = TypeVar('T_Retval')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class _NullAsyncContextManager:
							 | 
						||
| 
								 | 
							
								    async def __aenter__(self) -> None:
							 | 
						||
| 
								 | 
							
								        pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
							 | 
						||
| 
								 | 
							
								                        exc_val: Optional[BaseException],
							 | 
						||
| 
								 | 
							
								                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
							 | 
						||
| 
								 | 
							
								        return None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SocketAttribute(TypedAttributeSet):
								    """
								    Typed attribute keys describing a socket-backed object.

								    ``_SocketProvider.extra_attributes`` maps these keys to getters that read the wrapped
								    stdlib socket.
								    """

								    #: address family of the wrapped socket
								    family: AddressFamily = typed_attribute()
								    #: local address the wrapped socket is bound to
								    local_address: SockAddrType = typed_attribute()
								    #: local port of the wrapped socket (IP sockets only)
								    local_port: int = typed_attribute()
								    #: the wrapped stdlib socket object itself
								    raw_socket: socket.socket = typed_attribute()
								    #: address of the peer the wrapped socket is connected to
								    remote_address: SockAddrType = typed_attribute()
								    #: port of the peer the wrapped socket is connected to (IP sockets only)
								    remote_port: int = typed_attribute()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class _SocketProvider(TypedAttributeProvider):
								    """
								    Mixin that supplies :class:`SocketAttribute` values from a wrapped stdlib socket.

								    Concrete subclasses only have to provide the :attr:`_raw_socket` property.
								    """

								    @property
								    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
								        """
								        Map each applicable :class:`SocketAttribute` key to a zero-argument getter.

								        ``family``, ``local_address`` and ``raw_socket`` are always included.
								        ``remote_address`` is included when ``getpeername()`` succeeds. ``local_port`` is
								        included for ``AF_INET``/``AF_INET6`` sockets. ``remote_port`` is included for those
								        families when a peer address was found.
								        """
								        # Imported at call time, presumably to avoid an import cycle with _core.
								        from .._core._sockets import convert_ipv6_sockaddr as convert

								        getters: Dict[Any, Callable[[], Any]] = {}
								        getters[SocketAttribute.family] = lambda: self._raw_socket.family
								        getters[SocketAttribute.local_address] = lambda: convert(self._raw_socket.getsockname())
								        getters[SocketAttribute.raw_socket] = lambda: self._raw_socket

								        # The peer address is captured once, here. getpeername() raises OSError when the
								        # socket has no peer, and in that case no remote attributes are exposed.
								        try:
								            peer: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername())
								        except OSError:
								            peer = None

								        if peer is not None:
								            getters[SocketAttribute.remote_address] = lambda: peer

								        # Port attributes are only meaningful for IPv4/IPv6 sockets.
								        if self._raw_socket.family not in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
								            return getters

								        getters[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1]
								        if peer is not None:
								            peer_port = peer[1]
								            getters[SocketAttribute.remote_port] = lambda: peer_port

								        return getters

								    @property
								    @abstractmethod
								    def _raw_socket(self) -> socket.socket:
								        """The stdlib socket wrapped by this object; subclasses must implement it."""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SocketStream(ByteStream, _SocketProvider):
								    """
								    Byte stream whose data travels over a socket.

								    The :class:`~SocketAttribute` extra attributes that apply to the wrapped socket are
								    available on it.
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class UNIXSocketStream(SocketStream):
								    """Socket stream that can also pass file descriptors to and from its peer."""

								    @abstractmethod
								    async def send_fds(
								        self, message: bytes, fds: Collection[Union[int, IOBase]]
								    ) -> None:
								        """
								        Send ``message`` to the peer together with the given file descriptors.

								        :param message: the payload to send; must not be empty
								        :param fds: the files to pass, each given as a numeric file descriptor or as an
								            open file or socket object
								        """

								    @abstractmethod
								    async def receive_fds(
								        self, msglen: int, maxfds: int
								    ) -> Tuple[bytes, List[int]]:
								        """
								        Receive a message from the peer along with the file descriptors sent with it.

								        :param msglen: the length of the message expected from the peer
								        :param maxfds: the largest number of file descriptors to expect from the peer
								        :return: a ``(message, file_descriptors)`` tuple
								        """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SocketListener(Listener[SocketStream], _SocketProvider):
								    """
								    Listener that accepts incoming connections on a socket.

								    Supports all relevant extra attributes from :class:`~SocketAttribute`.
								    """

								    @abstractmethod
								    async def accept(self) -> SocketStream:
								        """Accept the next incoming connection and return it as a stream."""

								    async def serve(
								        self,
								        handler: Callable[[T_Stream], Any],
								        task_group: Optional[TaskGroup] = None,
								    ) -> None:
								        """
								        Accept connections in an endless loop and start ``handler(stream)`` for each one.

								        :param handler: callable started via ``task_group.start_soon`` with each accepted
								            stream
								        :param task_group: task group in which handlers are started; when omitted, a new
								            task group is created and entered for the duration of this call
								        """
								        from .. import create_task_group

								        # NOTE(review): this class binds Listener to SocketStream, so T_Stream is not a
								        # class type parameter here; Callable[[SocketStream], Any] is probably the
								        # intended annotation. Confirm before changing it.
								        context_manager: AsyncContextManager
								        if task_group is not None:
								            # The caller's task group is used as-is; it is not entered or exited here.
								            # contextlib.AsyncExitStack (Python 3.7+) could replace this no-op manager.
								            context_manager = _NullAsyncContextManager()
								        else:
								            # Create a task group and enter/exit it around the accept loop.
								            task_group = create_task_group()
								            context_manager = task_group

								        async with context_manager:
								            while True:
								                conn = await self.accept()
								                task_group.start_soon(handler, conn)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
								    """
								    Unconnected UDP socket whose items are ``(payload, (host, port))`` packets.

								    Supports all relevant extra attributes from :class:`~SocketAttribute`.
								    """

								    async def sendto(self, data: bytes, host: str, port: int) -> None:
								        """
								        Send ``data`` as a single datagram to ``host``/``port``.

								        Shorthand for ``send((data, (host, port)))``.
								        """
								        destination = (host, port)
								        return await self.send((data, destination))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
								    """
								    Represents a connected UDP socket, whose items are raw ``bytes`` payloads.

								    Supports all relevant extra attributes from :class:`~SocketAttribute`.
								    """
							 |