You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							151 lines
						
					
					
						
							4.9 KiB
						
					
					
				
			
		
		
	
	
							151 lines
						
					
					
						
							4.9 KiB
						
					
					
				import hashlib
 | 
						|
 | 
						|
import ecdsa
 | 
						|
 | 
						|
from jose.backends.base import Key
 | 
						|
from jose.constants import ALGORITHMS
 | 
						|
from jose.exceptions import JWKError
 | 
						|
from jose.utils import base64_to_long, long_to_base64
 | 
						|
 | 
						|
 | 
						|
class ECDSAECKey(Key):
    """
    Performs signing and verification operations using
    ECDSA and the specified hash function

    This class requires the ecdsa package to be installed.

    This is based off of the implementation in PyJWT 0.3.2

    The key material may be supplied as an ``ecdsa.SigningKey`` /
    ``ecdsa.VerifyingKey`` instance, a JWK ``dict``, or a PEM document given
    as ``str`` or ``bytes``.
    """

    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    # Hash function -> curve used when building keys from a JWK dict.
    CURVE_MAP = {
        SHA256: ecdsa.curves.NIST256p,
        SHA384: ecdsa.curves.NIST384p,
        SHA512: ecdsa.curves.NIST521p,
    }
    # Curve -> name emitted as the "crv" member by ``to_dict``.
    CURVE_NAMES = (
        (ecdsa.curves.NIST256p, "P-256"),
        (ecdsa.curves.NIST384p, "P-384"),
        (ecdsa.curves.NIST521p, "P-521"),
    )

    def __init__(self, key, algorithm):
        """Prepare an EC key for use with ``algorithm``.

        Args:
            key: An ``ecdsa.SigningKey`` or ``ecdsa.VerifyingKey``, a JWK
                ``dict``, or PEM data as ``str`` or ``bytes``.
            algorithm: One of the values in ``ALGORITHMS.EC``.

        Raises:
            JWKError: If ``algorithm`` is not an EC algorithm, if the key
                cannot be loaded, or if ``key`` is of an unsupported type.
        """
        if algorithm not in ALGORITHMS.EC:
            # f-string rather than "%": "%" formatting raises TypeError when
            # the right-hand operand happens to be a tuple.
            raise JWKError(f"hash_alg: {algorithm} is not a valid hash algorithm")

        self.hash_alg = {
            ALGORITHMS.ES256: self.SHA256,
            ALGORITHMS.ES384: self.SHA384,
            ALGORITHMS.ES512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        self.curve = self.CURVE_MAP.get(self.hash_alg)

        if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
            self.prepared_key = key
            return

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            #
            # Both attempts sit inside one outer ``try`` so that a failure
            # of the Signing Key fallback is reported as JWKError as well;
            # an exception raised inside an ``except`` clause is not seen
            # by its sibling ``except`` clauses.
            try:
                try:
                    key = ecdsa.VerifyingKey.from_pem(key)
                except ecdsa.der.UnexpectedDER:
                    key = ecdsa.SigningKey.from_pem(key)
            except Exception as e:
                raise JWKError(e) from e

            self.prepared_key = key
            return

        # f-string rather than "%": a tuple ``key`` would otherwise raise
        # TypeError here instead of the intended JWKError.
        raise JWKError(f"Unable to parse an ECKey from key: {key}")

    def _process_jwk(self, jwk_dict):
        """Build an ecdsa key object from a JWK dictionary.

        A ``SigningKey`` is returned when the private member "d" is present,
        otherwise a ``VerifyingKey`` built from the "x"/"y" point.

        NOTE(review): "crv" is only checked for presence; the curve actually
        used is ``self.curve``, which ``__init__`` derives from the
        algorithm. Confirm whether a mismatching "crv" should be rejected.

        Raises:
            JWKError: If "kty" is not "EC", if any of "x", "y", "crv" is
                missing, or if the public point is not valid for the curve.
        """
        kty = jwk_dict.get("kty")
        if kty != "EC":
            raise JWKError(f"Incorrect key type. Expected: 'EC', Received: {kty}")

        if not all(k in jwk_dict for k in ["x", "y", "crv"]):
            raise JWKError("Mandatory parameters are missing")

        if "d" in jwk_dict:
            # We are dealing with a private key; the secret exponent is enough
            # to create an ecdsa key.
            d = base64_to_long(jwk_dict.get("d"))
            return ecdsa.keys.SigningKey.from_secret_exponent(d, self.curve)

        x = base64_to_long(jwk_dict.get("x"))
        y = base64_to_long(jwk_dict.get("y"))

        if not ecdsa.ecdsa.point_is_valid(self.curve.generator, x, y):
            raise JWKError(f"Point: {x}, {y} is not a valid point")

        point = ecdsa.ellipticcurve.Point(self.curve.curve, x, y, self.curve.order)
        return ecdsa.keys.VerifyingKey.from_public_point(point, self.curve)

    def sign(self, msg):
        """Sign ``msg`` with the prepared key and the configured hash.

        The signature is encoded with ``ecdsa.util.sigencode_string``.
        """
        return self.prepared_key.sign(
            msg, hashfunc=self.hash_alg, sigencode=ecdsa.util.sigencode_string, allow_truncate=False
        )

    def verify(self, msg, sig):
        """Check ``sig`` against ``msg``.

        Returns:
            The result of the underlying ``verify`` call, or ``False`` if
            that call raises any exception.
        """
        try:
            return self.prepared_key.verify(
                sig, msg, hashfunc=self.hash_alg, sigdecode=ecdsa.util.sigdecode_string, allow_truncate=False
            )
        except Exception:
            # Deliberately broad: every failure during verification is
            # reported to the caller as "signature not valid".
            return False

    def is_public(self):
        """Return ``True`` when the prepared key is an ``ecdsa.VerifyingKey``."""
        return isinstance(self.prepared_key, ecdsa.VerifyingKey)

    def public_key(self):
        """Return a key object holding only the public part of this key.

        ``self`` is returned unchanged when it is already a public key.
        """
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.get_verifying_key(), self._algorithm)

    def to_pem(self):
        """Return the prepared key serialized via its ``to_pem`` method."""
        return self.prepared_key.to_pem()

    def to_dict(self):
        """Serialize the key as a JWK dictionary.

        The result always has "alg", "kty", "crv", "x" and "y"; "d" is added
        for private keys.

        Raises:
            KeyError: If the key's curve is not listed in ``CURVE_NAMES``.
        """
        if not self.is_public():
            public_key = self.prepared_key.get_verifying_key()
        else:
            public_key = self.prepared_key

        crv = None
        for curve, curve_name in self.CURVE_NAMES:
            if curve == self.prepared_key.curve:
                crv = curve_name
                break
        if not crv:
            raise KeyError(f"Can't match {self.prepared_key.curve}")

        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
        # points must be encoded as octet strings of this length.
        key_size = self.prepared_key.curve.baselen

        data = {
            "alg": self._algorithm,
            "kty": "EC",
            "crv": crv,
            "x": long_to_base64(public_key.pubkey.point.x(), size=key_size).decode("ASCII"),
            "y": long_to_base64(public_key.pubkey.point.y(), size=key_size).decode("ASCII"),
        }

        if not self.is_public():
            data["d"] = long_to_base64(self.prepared_key.privkey.secret_multiplier, size=key_size).decode("ASCII")

        return data
 |