You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					113 lines
				
				2.9 KiB
			
		
		
			
		
	
	
					113 lines
				
				2.9 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import re
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								import traceback
							 | 
						||
| 
								 | 
							
								from typing import NoReturn
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import pytest
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .._util import (
							 | 
						||
| 
								 | 
							
								    bytesify,
							 | 
						||
| 
								 | 
							
								    LocalProtocolError,
							 | 
						||
| 
								 | 
							
								    ProtocolError,
							 | 
						||
| 
								 | 
							
								    RemoteProtocolError,
							 | 
						||
| 
								 | 
							
								    Sentinel,
							 | 
						||
| 
								 | 
							
								    validate,
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_ProtocolError() -> None:
								    """ProtocolError is an abstract base: instantiating it directly must
								    raise TypeError."""
								    # Callable form of pytest.raises: calls ProtocolError("abstract base class")
								    # and requires it to raise TypeError.
								    pytest.raises(TypeError, ProtocolError, "abstract base class")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_LocalProtocolError() -> None:
								    """LocalProtocolError carries an error_status_hint, and can be re-raised
								    as a RemoteProtocolError that keeps its args, hint and traceback.
								
								    pytest.raises is used instead of bare try/except so that the test fails
								    if an expected exception is never raised, rather than silently skipping
								    the assertions that would have lived in the except-handler.
								    """
								    # Without an explicit hint, the status hint is 400.
								    with pytest.raises(LocalProtocolError) as default_info:
								        raise LocalProtocolError("foo")
								    assert str(default_info.value) == "foo"
								    assert default_info.value.error_status_hint == 400
								
								    with pytest.raises(LocalProtocolError) as custom_info:
								        raise LocalProtocolError("foo", error_status_hint=418)
								    assert str(custom_info.value) == "foo"
								    assert custom_info.value.error_status_hint == 418
								
								    def thunk() -> NoReturn:
								        raise LocalProtocolError("a", error_status_hint=420)
								
								    # _reraise_as_remote_protocol_error() is called from inside an except
								    # block. If it failed to raise, pytest.raises reports "DID NOT RAISE"
								    # instead of the test passing vacuously.
								    with pytest.raises(RemoteProtocolError) as remote_info:
								        try:
								            thunk()
								        except LocalProtocolError as exc1:
								            orig_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
								            exc1._reraise_as_remote_protocol_error()
								
								    exc2 = remote_info.value
								    # Exactly RemoteProtocolError, not some other subclass.
								    assert type(exc2) is RemoteProtocolError
								    assert exc2.args == ("a",)
								    assert exc2.error_status_hint == 420
								    # The re-raise must preserve the original frames: the traceback seen at
								    # the catch site ends with the one captured in the inner handler.
								    new_traceback = "".join(traceback.format_tb(remote_info.tb))
								    assert new_traceback.endswith(orig_traceback)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_validate() -> None:
								    """validate() returns the named groups on a whole-string match and
								    raises LocalProtocolError otherwise."""
								    pattern = re.compile(br"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")
								
								    # An incomplete input, plus inputs where only a prefix matches:
								    # successful partial matches are an error - must match whole string.
								    for rejected in (b"0.", b"0.1xx", b"0.1\n"):
								        with pytest.raises(LocalProtocolError):
								            validate(pattern, rejected)
								
								    assert validate(pattern, b"0.1") == {"group1": b"0", "group2": b"1"}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_validate_formatting() -> None:
								    """The message passed to validate() appears in the raised error; "{}"
								    placeholders are filled in only when format arguments are supplied."""
								    pattern = re.compile(br"foo")
								
								    def error_text(msg: str, *format_args: int) -> str:
								        # b"" never matches br"foo", so validate() is expected to raise.
								        with pytest.raises(LocalProtocolError) as excinfo:
								            validate(pattern, b"", msg, *format_args)
								        return str(excinfo.value)
								
								    assert "oops" in error_text("oops")
								    # With no format arguments the braces are kept literally.
								    assert "oops {}" in error_text("oops {}")
								    assert "oops 10 xx" in error_text("oops {} xx", 10)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_make_sentinel() -> None:
								    """A Sentinel class is its own type, reprs as its own name, is hashable,
								    and is distinct from every other Sentinel class."""
								
								    class S(Sentinel, metaclass=Sentinel):
								        pass
								
								    s_type = type(S)
								    assert repr(S) == "S"
								    assert S == S
								    assert s_type.__name__ == "S"
								    assert S in {S}
								    # The class is an instance of itself.
								    assert s_type is S
								
								    class S2(Sentinel, metaclass=Sentinel):
								        pass
								
								    assert repr(S2) == "S2"
								    assert S != S2
								    assert S not in {S2}
								    assert s_type is not type(S2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_bytesify() -> None:
								    """bytesify() converts bytes, bytearray and encodable str to bytes, and
								    rejects unencodable text and ints."""
								    expected = b"123"
								    assert bytesify(expected) == expected
								    assert bytesify(bytearray(expected)) == expected
								    assert bytesify("123") == expected
								
								    # Text with a character the conversion cannot encode is an error.
								    with pytest.raises(UnicodeEncodeError):
								        bytesify("\u1234")
								
								    # Integers are not accepted as input.
								    with pytest.raises(TypeError):
								        bytesify(10)
							 |