You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					151 lines
				
				4.9 KiB
			
		
		
			
		
	
	
					151 lines
				
				4.9 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import hashlib
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import ecdsa
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from jose.backends.base import Key
							 | 
						||
| 
								 | 
							
								from jose.constants import ALGORITHMS
							 | 
						||
| 
								 | 
							
								from jose.exceptions import JWKError
							 | 
						||
| 
								 | 
							
								from jose.utils import base64_to_long, long_to_base64
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ECDSAECKey(Key):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Performs signing and verification operations using
							 | 
						||
| 
								 | 
							
								    ECDSA and the specified hash function
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    This class requires the ecdsa package to be installed.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    This is based off of the implementation in PyJWT 0.3.2
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    SHA256 = hashlib.sha256
							 | 
						||
| 
								 | 
							
								    SHA384 = hashlib.sha384
							 | 
						||
| 
								 | 
							
								    SHA512 = hashlib.sha512
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    CURVE_MAP = {
							 | 
						||
| 
								 | 
							
								        SHA256: ecdsa.curves.NIST256p,
							 | 
						||
| 
								 | 
							
								        SHA384: ecdsa.curves.NIST384p,
							 | 
						||
| 
								 | 
							
								        SHA512: ecdsa.curves.NIST521p,
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    CURVE_NAMES = (
							 | 
						||
| 
								 | 
							
								        (ecdsa.curves.NIST256p, "P-256"),
							 | 
						||
| 
								 | 
							
								        (ecdsa.curves.NIST384p, "P-384"),
							 | 
						||
| 
								 | 
							
								        (ecdsa.curves.NIST521p, "P-521"),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, key, algorithm):
							 | 
						||
| 
								 | 
							
								        if algorithm not in ALGORITHMS.EC:
							 | 
						||
| 
								 | 
							
								            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.hash_alg = {
							 | 
						||
| 
								 | 
							
								            ALGORITHMS.ES256: self.SHA256,
							 | 
						||
| 
								 | 
							
								            ALGORITHMS.ES384: self.SHA384,
							 | 
						||
| 
								 | 
							
								            ALGORITHMS.ES512: self.SHA512,
							 | 
						||
| 
								 | 
							
								        }.get(algorithm)
							 | 
						||
| 
								 | 
							
								        self._algorithm = algorithm
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.curve = self.CURVE_MAP.get(self.hash_alg)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
							 | 
						||
| 
								 | 
							
								            self.prepared_key = key
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(key, dict):
							 | 
						||
| 
								 | 
							
								            self.prepared_key = self._process_jwk(key)
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(key, str):
							 | 
						||
| 
								 | 
							
								            key = key.encode("utf-8")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(key, bytes):
							 | 
						||
| 
								 | 
							
								            # Attempt to load key. We don't know if it's
							 | 
						||
| 
								 | 
							
								            # a Signing Key or a Verifying Key, so we try
							 | 
						||
| 
								 | 
							
								            # the Verifying Key first.
							 | 
						||
| 
								 | 
							
								            try:
							 | 
						||
| 
								 | 
							
								                key = ecdsa.VerifyingKey.from_pem(key)
							 | 
						||
| 
								 | 
							
								            except ecdsa.der.UnexpectedDER:
							 | 
						||
| 
								 | 
							
								                key = ecdsa.SigningKey.from_pem(key)
							 | 
						||
| 
								 | 
							
								            except Exception as e:
							 | 
						||
| 
								 | 
							
								                raise JWKError(e)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            self.prepared_key = key
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        raise JWKError("Unable to parse an ECKey from key: %s" % key)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _process_jwk(self, jwk_dict):
							 | 
						||
| 
								 | 
							
								        if not jwk_dict.get("kty") == "EC":
							 | 
						||
| 
								 | 
							
								            raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not all(k in jwk_dict for k in ["x", "y", "crv"]):
							 | 
						||
| 
								 | 
							
								            raise JWKError("Mandatory parameters are missing")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if "d" in jwk_dict:
							 | 
						||
| 
								 | 
							
								            # We are dealing with a private key; the secret exponent is enough
							 | 
						||
| 
								 | 
							
								            # to create an ecdsa key.
							 | 
						||
| 
								 | 
							
								            d = base64_to_long(jwk_dict.get("d"))
							 | 
						||
| 
								 | 
							
								            return ecdsa.keys.SigningKey.from_secret_exponent(d, self.curve)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            x = base64_to_long(jwk_dict.get("x"))
							 | 
						||
| 
								 | 
							
								            y = base64_to_long(jwk_dict.get("y"))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if not ecdsa.ecdsa.point_is_valid(self.curve.generator, x, y):
							 | 
						||
| 
								 | 
							
								                raise JWKError(f"Point: {x}, {y} is not a valid point")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            point = ecdsa.ellipticcurve.Point(self.curve.curve, x, y, self.curve.order)
							 | 
						||
| 
								 | 
							
								            return ecdsa.keys.VerifyingKey.from_public_point(point, self.curve)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def sign(self, msg):
							 | 
						||
| 
								 | 
							
								        return self.prepared_key.sign(
							 | 
						||
| 
								 | 
							
								            msg, hashfunc=self.hash_alg, sigencode=ecdsa.util.sigencode_string, allow_truncate=False
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def verify(self, msg, sig):
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            return self.prepared_key.verify(
							 | 
						||
| 
								 | 
							
								                sig, msg, hashfunc=self.hash_alg, sigdecode=ecdsa.util.sigdecode_string, allow_truncate=False
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        except Exception:
							 | 
						||
| 
								 | 
							
								            return False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def is_public(self):
							 | 
						||
| 
								 | 
							
								        return isinstance(self.prepared_key, ecdsa.VerifyingKey)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def public_key(self):
							 | 
						||
| 
								 | 
							
								        if self.is_public():
							 | 
						||
| 
								 | 
							
								            return self
							 | 
						||
| 
								 | 
							
								        return self.__class__(self.prepared_key.get_verifying_key(), self._algorithm)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def to_pem(self):
							 | 
						||
| 
								 | 
							
								        return self.prepared_key.to_pem()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def to_dict(self):
							 | 
						||
| 
								 | 
							
								        if not self.is_public():
							 | 
						||
| 
								 | 
							
								            public_key = self.prepared_key.get_verifying_key()
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            public_key = self.prepared_key
							 | 
						||
| 
								 | 
							
								        crv = None
							 | 
						||
| 
								 | 
							
								        for key, value in self.CURVE_NAMES:
							 | 
						||
| 
								 | 
							
								            if key == self.prepared_key.curve:
							 | 
						||
| 
								 | 
							
								                crv = value
							 | 
						||
| 
								 | 
							
								        if not crv:
							 | 
						||
| 
								 | 
							
								            raise KeyError(f"Can't match {self.prepared_key.curve}")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
							 | 
						||
| 
								 | 
							
								        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
							 | 
						||
| 
								 | 
							
								        # points must be encoded as octed-strings of this length.
							 | 
						||
| 
								 | 
							
								        key_size = self.prepared_key.curve.baselen
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        data = {
							 | 
						||
| 
								 | 
							
								            "alg": self._algorithm,
							 | 
						||
| 
								 | 
							
								            "kty": "EC",
							 | 
						||
| 
								 | 
							
								            "crv": crv,
							 | 
						||
| 
								 | 
							
								            "x": long_to_base64(public_key.pubkey.point.x(), size=key_size).decode("ASCII"),
							 | 
						||
| 
								 | 
							
								            "y": long_to_base64(public_key.pubkey.point.y(), size=key_size).decode("ASCII"),
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not self.is_public():
							 | 
						||
| 
								 | 
							
								            data["d"] = long_to_base64(self.prepared_key.privkey.secret_multiplier, size=key_size).decode("ASCII")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return data
							 |