You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					122 lines
				
				3.7 KiB
			
		
		
			
		
	
	
					122 lines
				
				3.7 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								"""trio async I/O library query support"""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import socket
							 | 
						||
| 
								 | 
							
								import trio
							 | 
						||
| 
								 | 
							
								import trio.socket  # type: ignore
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import dns._asyncbackend
							 | 
						||
| 
								 | 
							
								import dns.exception
							 | 
						||
| 
								 | 
							
								import dns.inet
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								    if timeout:
							 | 
						||
| 
								 | 
							
								        return trio.move_on_after(timeout)
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        return dns._asyncbackend.NullContext()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
# For brevity: short alias for dns.inet.low_level_address_tuple.
# It is called as _lltuple(address, af), and the result is passed directly
# to the trio socket's bind()/connect().  Presumably it converts a dnspython
# address tuple into the socket-level address form for family *af*; see
# dns.inet to confirm.
_lltuple = dns.inet.low_level_address_tuple
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# pylint: disable=redefined-outer-name
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DatagramSocket(dns._asyncbackend.DatagramSocket):
							 | 
						||
| 
								 | 
							
								    def __init__(self, socket):
							 | 
						||
| 
								 | 
							
								        self.socket = socket
							 | 
						||
| 
								 | 
							
								        self.family = socket.family
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sendto(self, what, destination, timeout):
							 | 
						||
| 
								 | 
							
								        with _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								            return await self.socket.sendto(what, destination)
							 | 
						||
| 
								 | 
							
								        raise dns.exception.Timeout(timeout=timeout)  # pragma: no cover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def recvfrom(self, size, timeout):
							 | 
						||
| 
								 | 
							
								        with _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								            return await self.socket.recvfrom(size)
							 | 
						||
| 
								 | 
							
								        raise dns.exception.Timeout(timeout=timeout)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def close(self):
							 | 
						||
| 
								 | 
							
								        self.socket.close()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getpeername(self):
							 | 
						||
| 
								 | 
							
								        return self.socket.getpeername()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getsockname(self):
							 | 
						||
| 
								 | 
							
								        return self.socket.getsockname()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class StreamSocket(dns._asyncbackend.StreamSocket):
							 | 
						||
| 
								 | 
							
								    def __init__(self, family, stream, tls=False):
							 | 
						||
| 
								 | 
							
								        self.family = family
							 | 
						||
| 
								 | 
							
								        self.stream = stream
							 | 
						||
| 
								 | 
							
								        self.tls = tls
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sendall(self, what, timeout):
							 | 
						||
| 
								 | 
							
								        with _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								            return await self.stream.send_all(what)
							 | 
						||
| 
								 | 
							
								        raise dns.exception.Timeout(timeout=timeout)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def recv(self, size, timeout):
							 | 
						||
| 
								 | 
							
								        with _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								            return await self.stream.receive_some(size)
							 | 
						||
| 
								 | 
							
								        raise dns.exception.Timeout(timeout=timeout)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def close(self):
							 | 
						||
| 
								 | 
							
								        await self.stream.aclose()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getpeername(self):
							 | 
						||
| 
								 | 
							
								        if self.tls:
							 | 
						||
| 
								 | 
							
								            return self.stream.transport_stream.socket.getpeername()
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            return self.stream.socket.getpeername()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def getsockname(self):
							 | 
						||
| 
								 | 
							
								        if self.tls:
							 | 
						||
| 
								 | 
							
								            return self.stream.transport_stream.socket.getsockname()
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            return self.stream.socket.getsockname()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Backend(dns._asyncbackend.Backend):
							 | 
						||
| 
								 | 
							
								    def name(self):
							 | 
						||
| 
								 | 
							
								        return 'trio'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def make_socket(self, af, socktype, proto=0, source=None,
							 | 
						||
| 
								 | 
							
								                          destination=None, timeout=None,
							 | 
						||
| 
								 | 
							
								                          ssl_context=None, server_hostname=None):
							 | 
						||
| 
								 | 
							
								        s = trio.socket.socket(af, socktype, proto)
							 | 
						||
| 
								 | 
							
								        stream = None
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            if source:
							 | 
						||
| 
								 | 
							
								                await s.bind(_lltuple(source, af))
							 | 
						||
| 
								 | 
							
								            if socktype == socket.SOCK_STREAM:
							 | 
						||
| 
								 | 
							
								                with _maybe_timeout(timeout):
							 | 
						||
| 
								 | 
							
								                    await s.connect(_lltuple(destination, af))
							 | 
						||
| 
								 | 
							
								        except Exception:  # pragma: no cover
							 | 
						||
| 
								 | 
							
								            s.close()
							 | 
						||
| 
								 | 
							
								            raise
							 | 
						||
| 
								 | 
							
								        if socktype == socket.SOCK_DGRAM:
							 | 
						||
| 
								 | 
							
								            return DatagramSocket(s)
							 | 
						||
| 
								 | 
							
								        elif socktype == socket.SOCK_STREAM:
							 | 
						||
| 
								 | 
							
								            stream = trio.SocketStream(s)
							 | 
						||
| 
								 | 
							
								            s = None
							 | 
						||
| 
								 | 
							
								            tls = False
							 | 
						||
| 
								 | 
							
								            if ssl_context:
							 | 
						||
| 
								 | 
							
								                tls = True
							 | 
						||
| 
								 | 
							
								                try:
							 | 
						||
| 
								 | 
							
								                    stream = trio.SSLStream(stream, ssl_context,
							 | 
						||
| 
								 | 
							
								                                            server_hostname=server_hostname)
							 | 
						||
| 
								 | 
							
								                except Exception:  # pragma: no cover
							 | 
						||
| 
								 | 
							
								                    await stream.aclose()
							 | 
						||
| 
								 | 
							
								                    raise
							 | 
						||
| 
								 | 
							
								            return StreamSocket(af, stream, tls)
							 | 
						||
| 
								 | 
							
								        raise NotImplementedError('unsupported socket ' +
							 | 
						||
| 
								 | 
							
								                                  f'type {socktype}')    # pragma: no cover
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    async def sleep(self, interval):
							 | 
						||
| 
								 | 
							
								        await trio.sleep(interval)
							 |