You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					102 lines
				
				3.0 KiB
			
		
		
			
		
	
	
					102 lines
				
				3.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import dns.exception
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# pylint: disable=unused-import
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from dns._asyncbackend import Socket, DatagramSocket, \
							 | 
						||
| 
								 | 
							
								    StreamSocket, Backend  # noqa:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# pylint: enable=unused-import
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Backend returned by get_default_backend(); stays None until
								# set_default_backend() assigns one.
								_default_backend = None

								# Backend instances created so far by get_backend(), keyed by name.
								_backends = {}

								# Allow sniffio import to be disabled for testing purposes
								_no_sniffio = False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Raised by sniff() when the asynchronous I/O library in use cannot be
								# determined.
								# NOTE(review): documented with comments rather than a docstring so the
								# class's __doc__ stays unchanged -- the DNSException base class may use
								# it when building a default message; confirm before adding a docstring.
								class AsyncLibraryNotFoundError(dns.exception.DNSException):
								    pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_backend(name):
								    """Get the specified asynchronous backend.

								    *name*, a ``str``, the name of the backend.  Currently the "trio",
								    "curio", and "asyncio" backends are available.

								    Returns the backend object for *name*.  The object is created on the
								    first request and cached, so later calls with the same name return
								    the same object.

								    Raises NotImplementedError if an unknown backend name is specified.
								    """
								    # pylint: disable=import-outside-toplevel,redefined-outer-name
								    backend = _backends.get(name)
								    if backend:
								        return backend
								    # The backend modules are imported only when first requested, so the
								    # matching async library is needed only for the backend actually used.
								    if name == 'trio':
								        import dns._trio_backend
								        backend = dns._trio_backend.Backend()
								    elif name == 'curio':
								        import dns._curio_backend
								        backend = dns._curio_backend.Backend()
								    elif name == 'asyncio':
								        import dns._asyncio_backend
								        backend = dns._asyncio_backend.Backend()
								    else:
								        raise NotImplementedError(f'unimplemented async backend {name}')
								    # Remember the instance so it is built only once per name.
								    _backends[name] = backend
								    return backend
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def sniff():
								    """Work out which asynchronous I/O library is currently in use.

								    The ``sniffio`` module is consulted when it can be imported and has
								    not been disabled through ``_no_sniffio``; otherwise the only check
								    made is whether an ``asyncio`` event loop is running.

								    Returns the name of the library.

								    Raises AsyncLibraryNotFoundError if the library cannot be determined.
								    """
								    # pylint: disable=import-outside-toplevel
								    try:
								        if _no_sniffio:
								            # Take the same path as a missing sniffio installation.
								            raise ImportError
								        import sniffio
								        try:
								            return sniffio.current_async_library()
								        except sniffio.AsyncLibraryNotFoundError:
								            raise AsyncLibraryNotFoundError(
								                'sniffio cannot determine async library')
								    except ImportError:
								        import asyncio
								        try:
								            asyncio.get_running_loop()
								        except RuntimeError:
								            # No event loop is running in this thread.
								            raise AsyncLibraryNotFoundError('no async library detected')
								        except AttributeError:  # pragma: no cover
								            # we have to check current_task on 3.6
								            if asyncio.Task.current_task():
								                return 'asyncio'
								            raise AsyncLibraryNotFoundError('no async library detected')
								        else:
								            return 'asyncio'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def get_default_backend():
								    """Return the default backend, choosing one first if none is set.

								    When no default has been set yet, the library reported by ``sniff()``
								    is installed as the default through ``set_default_backend()`` and the
								    resulting backend is returned.
								    """
								    if not _default_backend:
								        return set_default_backend(sniff())
								    return _default_backend
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def set_default_backend(name):
								    """Make the backend called *name* the default and return it.

								    Calling this is not normally required, because
								    ``get_default_backend()`` initializes the backend appropriately in
								    many cases.  It exists so the backend can be set explicitly when
								    ``sniffio`` is not installed, or in testing situations.

								    *name* is passed to ``get_backend()``, which looks up the backend.
								    """
								    global _default_backend
								    chosen = get_backend(name)
								    _default_backend = chosen
								    return chosen
							 |