You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					150 lines
				
				4.8 KiB
			
		
		
			
		
	
	
					150 lines
				
				4.8 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								"""asyncio library query support"""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import socket
							 | 
						||
| 
								 | 
							
								import asyncio
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import dns._asyncbackend
							 | 
						||
| 
								 | 
							
								import dns.exception
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								_is_win32 = sys.platform == 'win32'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _get_running_loop():
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        return asyncio.get_running_loop()
							 | 
						||
| 
								 | 
							
								    except AttributeError:  # pragma: no cover
							 | 
						||
| 
								 | 
							
								        return asyncio.get_event_loop()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class _DatagramProtocol:
							 | 
						||
| 
								 | 
							
								    def __init__(self):
							 | 
						||
| 
								 | 
							
								        self.transport = None
							 | 
						||
| 
								 | 
							
								        self.recvfrom = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def connection_made(self, transport):
							 | 
						||
| 
								 | 
							
								        self.transport = transport
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def datagram_received(self, data, addr):
							 | 
						||
| 
								 | 
							
								        if self.recvfrom and not self.recvfrom.done():
							 | 
						||
| 
								 | 
							
								            self.recvfrom.set_result((data, addr))
							 | 
						||
| 
								 | 
							
								            self.recvfrom = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def error_received(self, exc):  # pragma: no cover
							 | 
						||
| 
								 | 
							
								        if self.recvfrom and not self.recvfrom.done():
							 | 
						||
| 
								 | 
							
								            self.recvfrom.set_exception(exc)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def connection_lost(self, exc):
							 | 
						||
| 
								 | 
							
								        if self.recvfrom and not self.recvfrom.done():
							 | 
						||
| 
								 | 
							
								            self.recvfrom.set_exception(exc)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def close(self):
							 | 
						||
| 
								 | 
							
								        self.transport.close()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								async def _maybe_wait_for(awaitable, timeout):
							 | 
						||
| 
								 | 
							
								    if timeout:
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            return await asyncio.wait_for(awaitable, timeout)
							 | 
						||
| 
								 | 
							
								        except asyncio.TimeoutError:
							 | 
						||
| 
								 | 
							
								            raise dns.exception.Timeout(timeout=timeout)
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        return await awaitable
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DatagramSocket(dns._asyncbackend.DatagramSocket):
							 | 
						||
| 
								 | 
							
								    def __init__(self, family, transport, protocol):
							 | 
						||
| 
								 | 
							
								        self.family = family
							 | 
						||
| 
								 | 
							
								        self.transport = transport
							 | 
						||
| 
								 | 
							
								        self.protocol = protocol
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sendto(self, what, destination, timeout):  # pragma: no cover
							 | 
						||
| 
								 | 
							
								        # no timeout for asyncio sendto
							 | 
						||
| 
								 | 
							
								        self.transport.sendto(what, destination)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def recvfrom(self, size, timeout):
							 | 
						||
| 
								 | 
							
								        # ignore size as there's no way I know to tell protocol about it
							 | 
						||
| 
								 | 
							
								        done = _get_running_loop().create_future()
							 | 
						||
| 
								 | 
							
								        assert self.protocol.recvfrom is None
							 | 
						||
| 
								 | 
							
								        self.protocol.recvfrom = done
							 | 
						||
| 
								 | 
							
								        await _maybe_wait_for(done, timeout)
							 | 
						||
| 
								 | 
							
								        return done.result()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def close(self):
							 | 
						||
| 
								 | 
							
								        self.protocol.close()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getpeername(self):
							 | 
						||
| 
								 | 
							
								        return self.transport.get_extra_info('peername')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getsockname(self):
							 | 
						||
| 
								 | 
							
								        return self.transport.get_extra_info('sockname')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class StreamSocket(dns._asyncbackend.StreamSocket):
							 | 
						||
| 
								 | 
							
								    def __init__(self, af, reader, writer):
							 | 
						||
| 
								 | 
							
								        self.family = af
							 | 
						||
| 
								 | 
							
								        self.reader = reader
							 | 
						||
| 
								 | 
							
								        self.writer = writer
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sendall(self, what, timeout):
							 | 
						||
| 
								 | 
							
								        self.writer.write(what)
							 | 
						||
| 
								 | 
							
								        return await _maybe_wait_for(self.writer.drain(), timeout)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def recv(self, size, timeout):
							 | 
						||
| 
								 | 
							
								        return await _maybe_wait_for(self.reader.read(size),
							 | 
						||
| 
								 | 
							
								                                     timeout)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def close(self):
							 | 
						||
| 
								 | 
							
								        self.writer.close()
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            await self.writer.wait_closed()
							 | 
						||
| 
								 | 
							
								        except AttributeError:  # pragma: no cover
							 | 
						||
| 
								 | 
							
								            pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getpeername(self):
							 | 
						||
| 
								 | 
							
								        return self.writer.get_extra_info('peername')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getsockname(self):
							 | 
						||
| 
								 | 
							
								        return self.writer.get_extra_info('sockname')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Backend(dns._asyncbackend.Backend):
							 | 
						||
| 
								 | 
							
								    def name(self):
							 | 
						||
| 
								 | 
							
								        return 'asyncio'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def make_socket(self, af, socktype, proto=0,
							 | 
						||
| 
								 | 
							
								                          source=None, destination=None, timeout=None,
							 | 
						||
| 
								 | 
							
								                          ssl_context=None, server_hostname=None):
							 | 
						||
| 
								 | 
							
								        if destination is None and socktype == socket.SOCK_DGRAM and \
							 | 
						||
| 
								 | 
							
								           _is_win32:
							 | 
						||
| 
								 | 
							
								            raise NotImplementedError('destinationless datagram sockets '
							 | 
						||
| 
								 | 
							
								                                      'are not supported by asyncio '
							 | 
						||
| 
								 | 
							
								                                      'on Windows')
							 | 
						||
| 
								 | 
							
								        loop = _get_running_loop()
							 | 
						||
| 
								 | 
							
								        if socktype == socket.SOCK_DGRAM:
							 | 
						||
| 
								 | 
							
								            transport, protocol = await loop.create_datagram_endpoint(
							 | 
						||
| 
								 | 
							
								                _DatagramProtocol, source, family=af,
							 | 
						||
| 
								 | 
							
								                proto=proto, remote_addr=destination)
							 | 
						||
| 
								 | 
							
								            return DatagramSocket(af, transport, protocol)
							 | 
						||
| 
								 | 
							
								        elif socktype == socket.SOCK_STREAM:
							 | 
						||
| 
								 | 
							
								            (r, w) = await _maybe_wait_for(
							 | 
						||
| 
								 | 
							
								                asyncio.open_connection(destination[0],
							 | 
						||
| 
								 | 
							
								                                        destination[1],
							 | 
						||
| 
								 | 
							
								                                        ssl=ssl_context,
							 | 
						||
| 
								 | 
							
								                                        family=af,
							 | 
						||
| 
								 | 
							
								                                        proto=proto,
							 | 
						||
| 
								 | 
							
								                                        local_addr=source,
							 | 
						||
| 
								 | 
							
								                                        server_hostname=server_hostname),
							 | 
						||
| 
								 | 
							
								                timeout)
							 | 
						||
| 
								 | 
							
								            return StreamSocket(af, r, w)
							 | 
						||
| 
								 | 
							
								        raise NotImplementedError('unsupported socket ' +
							 | 
						||
| 
								 | 
							
								                                  f'type {socktype}')  # pragma: no cover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sleep(self, interval):
							 | 
						||
| 
								 | 
							
								        await asyncio.sleep(interval)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def datagram_connection_required(self):
							 | 
						||
| 
								 | 
							
								        return _is_win32
							 |